From 4b28686721b1f3c4f4a7865e858d11cd78e0c8af Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 2 Dec 2024 07:16:08 -0800 Subject: [PATCH 01/10] Added Initial Implementation of the Agent Search Graph --- .../agent_search/primary_graph/edges.py | 34 ++ .../primary_graph/graph_builder.py | 116 +++++ .../agent_search/primary_graph/nodes.py | 462 ++++++++++++++++++ .../agent_search/primary_graph/states.py | 61 +++ backend/danswer/agent_search/run_graph.py | 21 + .../shared_graph_utils/prompts.py | 122 +++++ .../agent_search/shared_graph_utils/utils.py | 15 + .../danswer/agent_search/subgraph/edges.py | 25 + .../agent_search/subgraph/graph_builder.py | 59 +++ .../danswer/agent_search/subgraph/nodes.py | 316 ++++++++++++ .../danswer/agent_search/subgraph/states.py | 71 +++ backend/requirements/default.txt | 13 +- 12 files changed, 1312 insertions(+), 3 deletions(-) create mode 100644 backend/danswer/agent_search/primary_graph/edges.py create mode 100644 backend/danswer/agent_search/primary_graph/graph_builder.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes.py create mode 100644 backend/danswer/agent_search/primary_graph/states.py create mode 100644 backend/danswer/agent_search/run_graph.py create mode 100644 backend/danswer/agent_search/shared_graph_utils/prompts.py create mode 100644 backend/danswer/agent_search/shared_graph_utils/utils.py create mode 100644 backend/danswer/agent_search/subgraph/edges.py create mode 100644 backend/danswer/agent_search/subgraph/graph_builder.py create mode 100644 backend/danswer/agent_search/subgraph/nodes.py create mode 100644 backend/danswer/agent_search/subgraph/states.py diff --git a/backend/danswer/agent_search/primary_graph/edges.py b/backend/danswer/agent_search/primary_graph/edges.py new file mode 100644 index 00000000000..577136b489b --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/edges.py @@ -0,0 +1,34 @@ +from langgraph.types import Send + +from danswer.agent_search.primary_graph.states import QAState + + +def continue_to_retrieval(state: QAState) -> list[Send]: + # Routes re-written queries to the (parallel) retrieval steps + # Notice the 'Send()' API that takes care of the parallelization + return [ + Send("custom_retrieve", {"query": query}) + for query in state["rewritten_queries"] + ] + + +def continue_to_answer_sub_questions(state: QAState) -> list[Send]: + # Routes re-written queries to the (parallel) retrieval steps + # Notice the 'Send()' API that takes care of the parallelization + return [ + Send("sub_answers_graph", {"base_answer_sub_question": sub_question}) + for sub_question in state["sub_questions"] + ] + + +def continue_to_verifier(state: QAState) -> list[Send]: + # Routes each de-douped retrieved doc to the verifier step - in parallel + # Notice the 'Send()' API that takes care of the parallelization + + return [ + Send( + "verifier", + {"document": doc, "original_question": state["original_question"]}, + ) + for doc in state["deduped_retrieval_docs"] + ] diff --git a/backend/danswer/agent_search/primary_graph/graph_builder.py b/backend/danswer/agent_search/primary_graph/graph_builder.py new file mode 100644 index 00000000000..c429af17de1 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/graph_builder.py @@ -0,0 +1,116 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions +from danswer.agent_search.primary_graph.edges import 
continue_to_retrieval +from danswer.agent_search.primary_graph.edges import continue_to_verifier +from danswer.agent_search.primary_graph.nodes import base_check +from danswer.agent_search.primary_graph.nodes import combine_retrieved_docs +from danswer.agent_search.primary_graph.nodes import consolidate_sub_qa +from danswer.agent_search.primary_graph.nodes import custom_retrieve +from danswer.agent_search.primary_graph.nodes import decompose +from danswer.agent_search.primary_graph.nodes import deep_answer_generation +from danswer.agent_search.primary_graph.nodes import final_stuff +from danswer.agent_search.primary_graph.nodes import generate +from danswer.agent_search.primary_graph.nodes import rewrite +from danswer.agent_search.primary_graph.nodes import verifier +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.subgraph.graph_builder import build_subgraph + + +def build_core_graph() -> StateGraph: + # Define the nodes we will cycle between + coreAnswerGraph = StateGraph(QAState) + + # Re-writing the question + coreAnswerGraph.add_node(node="rewrite", action=rewrite) + + # The retrieval step + coreAnswerGraph.add_node(node="custom_retrieve", action=custom_retrieve) + + # Combine and dedupe retrieved docs. + coreAnswerGraph.add_node( + node="combine_retrieved_docs", action=combine_retrieved_docs + ) + + # Verifying that a retrieved doc is relevant + coreAnswerGraph.add_node(node="verifier", action=verifier) + + sub_answers_graph = build_subgraph() + # Answering a sub-question + coreAnswerGraph.add_node(node="sub_answers_graph", action=sub_answers_graph) + + # A final clean-up step + coreAnswerGraph.add_node(node="final_stuff", action=final_stuff) + + # Decomposing the question into sub-questions + coreAnswerGraph.add_node(node="decompose", action=decompose) + + # Checking whether the initial answer is in the ballpark + coreAnswerGraph.add_node(node="base_check", action=base_check) + + # Generating a response after we know the documents are relevant + coreAnswerGraph.add_node(node="generate", action=generate) + + # Consolidating the sub-questions and answers + coreAnswerGraph.add_node(node="consolidate_sub_qa", action=consolidate_sub_qa) + + # Generating a deep answer + coreAnswerGraph.add_node( + node="deep_answer_generation", action=deep_answer_generation + ) + + ### Edges ### + + # start by rewriting the prompt + coreAnswerGraph.add_edge(start_key=START, end_key="rewrite") + + # Kick off another flow to decompose the question into sub-questions + coreAnswerGraph.add_edge(start_key=START, end_key="decompose") + + coreAnswerGraph.add_conditional_edges( + source="rewrite", + path=continue_to_retrieval, + path_map={"custom_retrieve": "custom_retrieve"}, + ) + + # check whether answer addresses the question + coreAnswerGraph.add_edge( + start_key="custom_retrieve", end_key="combine_retrieved_docs" + ) + + coreAnswerGraph.add_conditional_edges( + source="combine_retrieved_docs", + path=continue_to_verifier, + path_map={"verifier": "verifier"}, + ) + + coreAnswerGraph.add_conditional_edges( + source="decompose", + path=continue_to_answer_sub_questions, + path_map={"sub_answers_graph": "sub_answers_graph"}, + ) + + # use the retrieved information to generate the answer + coreAnswerGraph.add_edge(start_key="verifier", end_key="generate") + + # check whether answer addresses the question + coreAnswerGraph.add_edge(start_key="generate", end_key="base_check") + + coreAnswerGraph.add_edge( + start_key="sub_answers_graph", end_key="consolidate_sub_qa" + ) + + 
coreAnswerGraph.add_edge( + start_key="consolidate_sub_qa", end_key="deep_answer_generation" + ) + + coreAnswerGraph.add_edge( + start_key=["base_check", "deep_answer_generation"], end_key="final_stuff" + ) + + coreAnswerGraph.add_edge(start_key="final_stuff", end_key=END) + coreAnswerGraph.compile() + + return coreAnswerGraph diff --git a/backend/danswer/agent_search/primary_graph/nodes.py b/backend/danswer/agent_search/primary_graph/nodes.py new file mode 100644 index 00000000000..19e8b38e5f7 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes.py @@ -0,0 +1,462 @@ +import datetime +import json +from typing import Any +from typing import Dict +from typing import Literal + +from langchain_core.messages import HumanMessage +from pydantic import BaseModel + +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.primary_graph.states import RetrieverState +from danswer.agent_search.primary_graph.states import VerifierState +from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT +from danswer.agent_search.shared_graph_utils.prompts import DECOMPOSE_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_docs +from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace +from danswer.chat.models import DanswerContext +from danswer.llm.interfaces import LLM + + +# Pydantic models for structured outputs +class RewrittenQueries(BaseModel): + rewritten_queries: list[str] + + +class BinaryDecision(BaseModel): + decision: Literal["yes", "no"] + + +class SubQuestions(BaseModel): + sub_questions: list[str] + + +# Transform the initial question into more suitable search queries. +def rewrite(qa_state: QAState) -> Dict[str, Any]: + """ + Transform the initial question into more suitable search queries. 
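+
+    Example (illustrative values only; the actual queries depend on the LLM):
+        original_question: "What is Danswer?"
+        rewritten_queries: ["Danswer overview", "what does Danswer do"]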
+ + Args: + qa_state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + + print("---TRANSFORM QUERY---") + + start_time = datetime.datetime.now() + + question = qa_state["original_question"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI.format(question=question), + ) + ] + + # Get the rewritten queries in a defined format + llm: LLM = qa_state["llm"] + tools: list[dict] = qa_state["tools"] + response = list( + llm.stream( + prompt=msg, + tools=tools, + structured_response_format=RewrittenQueries.model_json_schema(), + ) + ) + + formatted_response: RewrittenQueries = json.loads(response[0].content) + + end_time = datetime.datetime.now() + return { + "rewritten_queries": formatted_response.rewritten_queries, + "log_messages": f"{str(start_time)} - {str(end_time)}: core - rewrite", + } + + +def custom_retrieve(retriever_state: RetrieverState) -> Dict[str, Any]: + """ + Retrieve documents + + Args: + retriever_state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print("---RETRIEVE---") + + start_time = datetime.datetime.now() + + retriever_state["rewritten_query"] + + # Retrieval + # TODO: add the actual retrieval, probably from search_tool.run() + documents: list[DanswerContext] = [] + + end_time = datetime.datetime.now() + return { + "base_retrieval_docs": documents, + "log_messages": f"{str(start_time)} - {str(end_time)}: core - custom_retrieve", + } + + +def combine_retrieved_docs(qa_state: QAState) -> Dict[str, Any]: + """ + Dedupe the retrieved docs. + """ + start_time = datetime.datetime.now() + + base_retrieval_docs = qa_state["base_retrieval_docs"] + + print(f"Number of docs from steps: {len(base_retrieval_docs)}") + dedupe_docs = [] + for base_retrieval_doc in base_retrieval_docs: + if base_retrieval_doc not in dedupe_docs: + dedupe_docs.append(base_retrieval_doc) + + print(f"Number of deduped docs: {len(dedupe_docs)}") + + end_time = datetime.datetime.now() + return { + "deduped_retrieval_docs": dedupe_docs, + "log_messages": f"{str(start_time)} - {str(end_time)}: core - combine_retrieved_docs (dedupe)", + } + + +def verifier(state: VerifierState) -> Dict[str, Any]: + """ + Check whether the document is relevant for the original user question + + Args: + state (VerifierState): The current state + + Returns: + dict: ict: The updated state with the final decision + """ + + print("---VERIFY QUTPUT---") + start_time = datetime.datetime.now() + + question = state["original_question"] + document_content = state["document"].content + + msg = [ + HumanMessage( + content=VERIFIER_PROMPT.format( + question=question, document_content=document_content + ) + ) + ] + + # Grader + llm: LLM = state["llm"] + tools: list[dict] = state["tools"] + response = list( + llm.stream( + prompt=msg, + tools=tools, + structured_response_format=BinaryDecision.model_json_schema(), + ) + ) + + formatted_response: BinaryDecision = response[0].content + + end_time = datetime.datetime.now() + if formatted_response.decision == "yes": + end_time = datetime.datetime.now() + return { + "deduped_retrieval_docs": [state["document"]], + "log_messages": f"{str(start_time)} - {str(end_time)}: core - verifier: yes", + } + else: + end_time = datetime.datetime.now() + return { + "deduped_retrieval_docs": [], + "log_messages": f"{str(start_time)} - {str(end_time)}: core - verifier: no", + } + + +def generate(qa_state: QAState) -> Dict[str, Any]: + """ + Generate answer + + 
Args: + qa_state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + print("---GENERATE---") + start_time = datetime.datetime.now() + + question = qa_state["original_question"] + docs = qa_state["deduped_retrieval_docs"] + + print(f"Number of verified retrieval docs: {docs}") + + # LLM + llm: LLM = qa_state["llm"] + + # Chain + # rag_chain = BASE_RAG_PROMPT | llm | StrOutputParser() + + msg = [ + HumanMessage( + content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) + ) + ] + + # Grader + llm: LLM = qa_state["llm"] + tools: list[dict] = qa_state["tools"] + response = list( + llm.stream( + prompt=msg, + tools=tools, + structured_response_format=None, + ) + ) + + # Run + # response = rag_chain.invoke({"context": docs, + # "question": question}) + + end_time = datetime.datetime.now() + return { + "base_answer": response[0].content, + "log_messages": f"{str(start_time)} - {str(end_time)}: core - generate", + } + + +def base_check(qa_state: QAState) -> Dict[str, Any]: + """ + Check whether the final output satisfies the original user question + + Args: + qa_state (messages): The current state + + Returns: + dict: ict: The updated state with the final decision + """ + + print("---CHECK QUTPUT---") + start_time = datetime.datetime.now() + + # time.sleep(5) + + initial_base_answer = qa_state["initial_base_answer"] + + question = qa_state["original_question"] + + BASE_CHECK_MESSAGE = [ + HumanMessage( + content=BASE_CHECK_PROMPT.format( + question=question, base_answer=initial_base_answer + ) + ) + ] + + llm: LLM = qa_state["llm"] + tools: list[dict] = qa_state["tools"] + response = list( + llm.stream( + prompt=BASE_CHECK_MESSAGE, + tools=tools, + structured_response_format=None, + ) + ) + + print(f"Verdict: {response[0].content}") + + end_time = datetime.datetime.now() + return { + "base_answer": initial_base_answer, + "log_messages": f"{str(start_time)} - {str(end_time)}: core - base_check", + } + + +def final_stuff(qa_state: QAState) -> Dict[str, Any]: + """ + Invokes the agent model to generate a response based on the current state. Given + the question, it will decide to retrieve using the retriever tool, or simply end. 
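+
+    The printed log timeline is assembled from the "log_messages" each node
+    returns, formatted as "<start> - <end>: <graph> - <node>", e.g.
+    "2024-12-02 07:16:08.000000 - 2024-12-02 07:16:09.000000: core - rewrite"
+    (timestamps illustrative).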
+ + Args: + qa_state (messages): The current state + + Returns: + dict: The updated state with the agent response appended to messages + """ + print("---FINAL---") + start_time = datetime.datetime.now() + + messages = qa_state["log_messages"] + time_ordered_messages = [x.content for x in messages] + time_ordered_messages.sort() + + print("Message Log:") + print("\n".join(time_ordered_messages)) + + end_time = datetime.datetime.now() + print(f"{str(start_time)} - {str(end_time)}: all - final_stuff") + + print("--------------------------------") + + base_answer = qa_state["base_answer"] + deep_answer = qa_state["deep_answer"] + sub_qas = qa_state["checked_sub_qas"] + sub_qa_list = [] + for sub_qa in sub_qas: + sub_qa_list.append( + f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----' + ) + sub_qa_context = "\n".join(sub_qa_list) + + print(f"Final Base Answer:\n{base_answer}") + print("--------------------------------") + print(f"Final Deep Answer:\n{deep_answer}") + print("--------------------------------") + print("Sub Questions and Answers:") + print(sub_qa_context) + + return { + "log_messages": f"{str(start_time)} - {str(end_time)}: all - final_stuff", + } + + +# nodes + + +def decompose(qa_state: QAState) -> Dict[str, Any]: + """ + Decompose a complex question into simpler sub-questions. + + Args: + qa_state: The current QA state containing the original question and LLM + + Returns: + Dict containing sub_questions and log messages + """ + + start_time = datetime.datetime.now() + + question = qa_state["original_question"] + + msg = [ + HumanMessage( + content=DECOMPOSE_PROMPT.format(question=question), + ) + ] + + # Grader + llm: LLM = qa_state["llm"] + tools: list[dict] = qa_state["tools"] + response = list( + llm.stream( + prompt=msg, + tools=tools, + structured_response_format=SubQuestions.model_json_schema(), + ) + ) + + formatted_response: SubQuestions = response[0].content + + end_time = datetime.datetime.now() + return { + "sub_questions": formatted_response.sub_questions, + "log_messages": f"{str(start_time)} - {str(end_time)}: deep - decompose", + } + + +# aggregate sub questions and answers +def consolidate_sub_qa(qa_state: QAState) -> Dict[str, Any]: + """ + Consolidate sub-questions and their answers. 
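+
+    Each entry of qa_state["sub_qas"] is expected to carry the keys produced
+    by sub_final_format, e.g. {"sub_question": "...", "sub_answer": "...",
+    "sub_answer_check": "yes"} (values illustrative); only entries whose
+    check is "yes" are folded into the dynamic context.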
+ + Args: + qa_state: The current QA state containing sub QAs + + Returns: + Dict containing dynamic context, checked sub QAs and log messages + """ + sub_qas = qa_state["sub_qas"] + + start_time = datetime.datetime.now() + + dynamic_context_list = [ + "Below you will find useful information to answer the original question:" + ] + checked_sub_qas = [] + + for sub_qa in sub_qas: + question = sub_qa["sub_question"] + answer = sub_qa["sub_answer"] + verified = sub_qa["sub_answer_check"] + + if verified == "yes": + dynamic_context_list.append( + f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n" + ) + checked_sub_qas.append({"sub_question": question, "sub_answer": answer}) + dynamic_context = "\n".join(dynamic_context_list) + + end_time = datetime.datetime.now() + return { + "dynamic_context": dynamic_context, + "checked_sub_qas": checked_sub_qas, + "log_messages": f"{str(start_time)} - {str(end_time)}: deep - consolidate_sub_qa", + } + + +# aggregate sub questions and answers +def deep_answer_generation(qa_state: QAState) -> Dict[str, Any]: + """ + Generate answer + + Args: + qa_state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + print("---GENERATE---") + start_time = datetime.datetime.now() + + question = qa_state["original_question"] + docs = qa_state["deduped_retrieval_docs"] + + deep_answer_context = qa_state["dynamic_context"] + + print(f"Number of verified retrieval docs: {docs}") + + combined_context = normalize_whitespace( + COMBINED_CONTEXT.format( + deep_answer_context=deep_answer_context, formated_docs=format_docs(docs) + ) + ) + + msg = [ + HumanMessage( + content=MODIFIED_RAG_PROMPT.format( + question=question, combined_context=combined_context + ) + ) + ] + + # Grader + # LLM + model: LLM = qa_state["llm"] + response = model.invoke(msg) + + end_time = datetime.datetime.now() + return { + "final_deep_answer": response.content, + "log_messages": f"{str(start_time)} - {str(end_time)}: deep - deep_answer_generation", + } + # return {"log_messages": [response]} diff --git a/backend/danswer/agent_search/primary_graph/states.py b/backend/danswer/agent_search/primary_graph/states.py new file mode 100644 index 00000000000..4729d5f8f63 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/states.py @@ -0,0 +1,61 @@ +import operator +from collections.abc import Sequence +from typing import Annotated +from typing import List +from typing import TypedDict + +from langchain_core.messages import BaseMessage +from langgraph.graph.message import add_messages + +from danswer.chat.models import DanswerContext +from danswer.llm.interfaces import LLM + + +class QAState(TypedDict): + # The 'main' state of the answer graph + original_question: str + log_messages: Annotated[Sequence[BaseMessage], add_messages] + rewritten_queries: List[str] + sub_questions: List[str] + sub_qas: Annotated[Sequence[dict], operator.add] + checked_sub_qas: Annotated[Sequence[dict], operator.add] + base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + deduped_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + reranked_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + retrieved_entities_relationships: dict + questions_context: List[dict] + qa_level: int + top_chunks: List[DanswerContext] + sub_question_top_chunks: Annotated[Sequence[dict], operator.add] + dynamic_context: str + initial_base_answer: str + base_answer: str + deep_answer: str + llm: LLM + tools: list[dict] + + +class 
QAOutputState(TypedDict):
+    # The 'main' output state of the answer graph. Removes all the intermediate states
+    original_question: str
+    log_messages: Annotated[Sequence[BaseMessage], add_messages]
+    sub_questions: List[str]
+    sub_qas: Annotated[Sequence[dict], operator.add]
+    checked_sub_qas: Annotated[Sequence[dict], operator.add]
+    reranked_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add]
+    retrieved_entities_relationships: dict
+    top_chunks: List[DanswerContext]
+    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
+    base_answer: str
+    deep_answer: str
+
+
+class RetrieverState(TypedDict):
+    # The state for the parallel Retrievers. They each need to see only one query
+    rewritten_query: str
+
+
+class VerifierState(TypedDict):
+    # The state for the parallel verification step. Each node execution needs to see only one question/doc pair
+    document: DanswerContext
+    original_question: str
diff --git a/backend/danswer/agent_search/run_graph.py b/backend/danswer/agent_search/run_graph.py
new file mode 100644
index 00000000000..bbee5ce0270
--- /dev/null
+++ b/backend/danswer/agent_search/run_graph.py
@@ -0,0 +1,21 @@
+from danswer.agent_search.primary_graph.graph_builder import build_core_graph
+from danswer.llm.answering.answer import AnswerStream
+from danswer.llm.interfaces import LLM
+from danswer.tools.tool import Tool
+
+
+def run_graph(
+    query: str,
+    llm: LLM,
+    tools: list[Tool],
+) -> AnswerStream:
+    graph = build_core_graph()
+
+    inputs = {
+        "original_question": query,
+        "messages": [],
+        "tools": tools,
+        "llm": llm,
+    }
+    output = graph.invoke(input=inputs)
+    yield from output
diff --git a/backend/danswer/agent_search/shared_graph_utils/prompts.py b/backend/danswer/agent_search/shared_graph_utils/prompts.py
new file mode 100644
index 00000000000..3cf27167ae3
--- /dev/null
+++ b/backend/danswer/agent_search/shared_graph_utils/prompts.py
@@ -0,0 +1,122 @@
+REWRITE_PROMPT_MULTI = """\n
+    Please convert an initial user question into 2-3 more appropriate
+    search queries for retrieval from a document store.\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Formulate the query:
+"""
+
+REWRITE_PROMPT_SINGLE = """\n
+    Please convert an initial user question into a more appropriate search
+    query for retrieval from a document store.\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Formulate the query:
+"""
+
+
+BASE_RAG_PROMPT = """\n
+    You are an assistant for question-answering tasks.
+    Use the following pieces of retrieved context to answer the question.
+    If you don't know the answer, just say that you don't know.
+    Use three sentences maximum and keep the answer concise.
+    \nQuestion: {question}
+    \nContext: {context}
+    \nAnswer:
+"""
+
+MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks.
+    Use the following pieces of retrieved context to answer the question.
+    If you don't know the answer, just say that you don't know.
+    Use three sentences maximum and keep the answer concise.
+    Also pay particular attention to the sub-questions and their answers,
+    as they may enrich the answer.
+    \nQuestion: {question}
+    \nContext: {combined_context}
+    \nAnswer:
+"""
+
+BASE_CHECK_PROMPT = """\n
+    Please check whether the suggested answer seems to address the original question. Please only answer with 'yes' or 'no'\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Here is the proposed answer:
+    \n ------- \n
+    {base_answer}
+    \n ------- \n
+    Please answer with yes or no:
+"""
+
+
+VERIFIER_PROMPT = """\n
+    Please check whether the document seems to be relevant for answering
+    the original question. Please only answer with 'yes' or 'no'\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Here is the document text:
+    \n ------- \n
+    {document_content}
+    \n ------- \n
+    Please answer with yes or no:
+"""
+
+DECOMPOSE_PROMPT = """\n
+    For an initial user question, please generate at least two but no
+    more than three individual sub-questions whose answers would help\n
+    to answer the initial question. The individual questions should be
+    answerable by a good RAG system. So a good idea would be to\n
+    use the sub-questions to resolve ambiguities and/or to separate the
+    question for different entities that may be involved in the original question.
+
+    Guidelines:
+    - The sub-questions should be specific to the question and provide
+    richer context for the question, and/or resolve ambiguities
+    - Each sub-question - when answered - should be relevant for the
+    answer to the original question
+    - The sub-questions MUST have the full context of the original
+    question so that they can be executed by a RAG system independently
+    without the original question available
+    (Example:
+    - initial question: "What is the capital of France?"
+    - bad sub-question: "What is the name of the river there?"
+    - good sub-question: "What is the name of the river that flows
+    through Paris?")
+
+    \n\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Please generate the list of good, fully contextualized sub-questions.
+    Think through it step by step and then generate the list.
+"""
+
+
+#### Consolidations
+COMBINED_CONTEXT = """
+    -------
+    Below you will find useful information to answer the original question.
+    First, you see a number of sub-questions with their answers.
+    This information should be considered to be more focused and
+    somewhat more specific to the original question, as it tries to
+    contextualize facts.
+    After that you will see the documents that were considered to be
+    relevant to answer the original question.
+ + Here are the sub-questions and their answers: + \n\n {deep_answer_context} \n\n + \n\n Here are the documents that were considered to be relevant to + answer the original question: + \n\n {formated_docs} \n\n + ---------------- +""" diff --git a/backend/danswer/agent_search/shared_graph_utils/utils.py b/backend/danswer/agent_search/shared_graph_utils/utils.py new file mode 100644 index 00000000000..1a3383ccd80 --- /dev/null +++ b/backend/danswer/agent_search/shared_graph_utils/utils.py @@ -0,0 +1,15 @@ +from collections.abc import Sequence + +from danswer.chat.models import DanswerContext + + +def normalize_whitespace(text: str) -> str: + """Normalize whitespace in text to single spaces and strip leading/trailing whitespace.""" + import re + + return re.sub(r"\s+", " ", text.strip()) + + +# Post-processing +def format_docs(docs: Sequence[DanswerContext]) -> str: + return "\n\n".join(doc.content for doc in docs) diff --git a/backend/danswer/agent_search/subgraph/edges.py b/backend/danswer/agent_search/subgraph/edges.py new file mode 100644 index 00000000000..2d293a5a34d --- /dev/null +++ b/backend/danswer/agent_search/subgraph/edges.py @@ -0,0 +1,25 @@ +from langgraph.types import Send + +from danswer.agent_search.primary_graph.states import QAState + + +def sub_continue_to_verifier(qa_state: QAState) -> list[Send]: + # Routes each de-douped retrieved doc to the verifier step - in parallel + # Notice the 'Send()' API that takes care of the parallelization + + return [ + Send( + "verifier", + {"document": doc, "question": qa_state["sub_question"]}, + ) + for doc in qa_state["sub_question_base_retrieval_docs"] + ] + + +def sub_continue_to_retrieval(qa_state: QAState) -> list[Send]: + # Routes re-written queries to the (parallel) retrieval steps + # Notice the 'Send()' API that takes care of the parallelization + return [ + Send("sub_custom_retrieve", {"rewritten_query": query}) + for query in qa_state["sub_question_rewritten_queries"] + ] diff --git a/backend/danswer/agent_search/subgraph/graph_builder.py b/backend/danswer/agent_search/subgraph/graph_builder.py new file mode 100644 index 00000000000..94a5e976dbe --- /dev/null +++ b/backend/danswer/agent_search/subgraph/graph_builder.py @@ -0,0 +1,59 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from danswer.agent_search.subgraph.edges import sub_continue_to_retrieval +from danswer.agent_search.subgraph.edges import sub_continue_to_verifier +from danswer.agent_search.subgraph.nodes import sub_combine_retrieved_docs +from danswer.agent_search.subgraph.nodes import sub_custom_retrieve +from danswer.agent_search.subgraph.nodes import sub_final_format +from danswer.agent_search.subgraph.nodes import sub_generate +from danswer.agent_search.subgraph.nodes import sub_qa_check +from danswer.agent_search.subgraph.nodes import sub_rewrite +from danswer.agent_search.subgraph.nodes import verifier +from danswer.agent_search.subgraph.states import SubQAOutputState +from danswer.agent_search.subgraph.states import SubQAState + + +def build_subgraph() -> StateGraph: + sub_answers = StateGraph(SubQAState, output=SubQAOutputState) + sub_answers.add_node(node="sub_rewrite", action=sub_rewrite) + sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve) + sub_answers.add_node( + node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs + ) + sub_answers.add_node(node="verifier", action=verifier) + sub_answers.add_node(node="sub_generate", action=sub_generate) + 
sub_answers.add_node(node="sub_qa_check", action=sub_qa_check) + sub_answers.add_node(node="sub_final_format", action=sub_final_format) + + sub_answers.add_edge(START, "sub_rewrite") + + sub_answers.add_conditional_edges( + "sub_rewrite", sub_continue_to_retrieval, ["sub_custom_retrieve"] + ) + + sub_answers.add_edge("sub_custom_retrieve", "sub_combine_retrieved_docs") + + sub_answers.add_conditional_edges( + "sub_combine_retrieved_docs", sub_continue_to_verifier, ["verifier"] + ) + + sub_answers.add_edge("verifier", "sub_generate") + + sub_answers.add_edge("sub_generate", "sub_qa_check") + + sub_answers.add_edge("sub_qa_check", "sub_final_format") + + sub_answers.add_edge("sub_final_format", END) + sub_answers_graph = sub_answers.compile() + return sub_answers_graph + + +if __name__ == "__main__": + # TODO: add the actual question + inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"} + sub_answers_graph = build_subgraph() + output = sub_answers_graph.invoke(inputs) + print("\nOUTPUT:") + print(output) diff --git a/backend/danswer/agent_search/subgraph/nodes.py b/backend/danswer/agent_search/subgraph/nodes.py new file mode 100644 index 00000000000..27cf0243859 --- /dev/null +++ b/backend/danswer/agent_search/subgraph/nodes.py @@ -0,0 +1,316 @@ +import datetime +import json +from typing import Any +from typing import Dict +from typing import Literal + +from langchain_core.messages import HumanMessage +from pydantic import BaseModel + +from danswer.agent_search.primary_graph.states import RetrieverState +from danswer.agent_search.primary_graph.states import VerifierState +from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_docs +from danswer.agent_search.subgraph.states import SubQAOutputState +from danswer.agent_search.subgraph.states import SubQAState +from danswer.chat.models import DanswerContext +from danswer.llm.interfaces import LLM + + +class BinaryDecision(BaseModel): + decision: Literal["yes", "no"] + + +# unused at this point. Kept from tutorial. But here the agent makes a routing decision +# not that much of an agent if the routing is static... +def sub_rewrite(sub_qa_state: SubQAState) -> Dict[str, Any]: + """ + Transform the initial question into more suitable search queries. 
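+
+    NOTE: the LLM call is currently stubbed out; the hard-coded example
+    queries in the body below stand in for real rewrites until the
+    structured-output call is wired up.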
+
+    Args:
+        sub_qa_state (SubQAState): The current state
+
+    Returns:
+        dict: The updated state with the rewritten search queries
+    """
+
+    print("---SUB TRANSFORM QUERY---")
+
+    start_time = datetime.datetime.now()
+
+    # messages = state["base_answer_messages"]
+    question = sub_qa_state["sub_question"]
+
+    msg = [
+        HumanMessage(
+            content=REWRITE_PROMPT_MULTI.format(question=question),
+        )
+    ]
+    print(msg)
+
+    # Get the rewritten queries in a defined format
+    ##response = model.with_structured_output(RewrittenQuery).invoke(msg)
+    ##rewritten_query = response.base_answer_rewritten_query
+
+    # Hard-coded stand-in until the structured-output call above is wired up
+    rewritten_queries = ["music hard to listen to", "Music that is not fun or pleasant"]
+
+    end_time = datetime.datetime.now()
+    return {
+        "sub_question_rewritten_queries": rewritten_queries,
+        "log_messages": f"{str(start_time)} - {str(end_time)}: sub - rewrite",
+    }
+
+
+# dummy node to report on state if needed
+def sub_custom_retrieve(retriever_state: RetrieverState) -> Dict[str, Any]:
+    """
+    Retrieve documents
+
+    Args:
+        retriever_state (RetrieverState): The current graph state
+
+    Returns:
+        state (dict): New key added to state, documents, that contains retrieved documents
+    """
+    print("---RETRIEVE SUB---")
+
+    start_time = datetime.datetime.now()
+
+    # query = retriever_state["rewritten_query"]
+
+    # Retrieval
+    # TODO: add the actual retrieval, probably from search_tool.run()
+    documents: list[DanswerContext] = []
+
+    end_time = datetime.datetime.now()
+    return {
+        "sub_question_base_retrieval_docs": documents,
+        "log_messages": f"{str(start_time)} - {str(end_time)}: sub - custom_retrieve",
+    }
+
+
+def sub_combine_retrieved_docs(sub_qa_state: SubQAState) -> Dict[str, Any]:
+    """
+    Dedupe the retrieved docs.
+    """
+    start_time = datetime.datetime.now()
+
+    sub_question_base_retrieval_docs = sub_qa_state["sub_question_base_retrieval_docs"]
+
+    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
+    dedupe_docs = []
+    for base_retrieval_doc in sub_question_base_retrieval_docs:
+        if base_retrieval_doc not in dedupe_docs:
+            dedupe_docs.append(base_retrieval_doc)
+
+    print(f"Number of deduped docs: {len(dedupe_docs)}")
+
+    end_time = datetime.datetime.now()
+    return {
+        "sub_question_deduped_retrieval_docs": dedupe_docs,
+        "log_messages": f"{str(start_time)} - {str(end_time)}: sub - combine_retrieved_docs (dedupe)",
+    }
+
+
+def verifier(verifier_state: VerifierState) -> Dict[str, Any]:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        verifier_state (VerifierState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---VERIFY OUTPUT---")
+    start_time = datetime.datetime.now()
+
+    question = verifier_state["original_question"]
+    document_content = verifier_state["document"].content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=question, document_content=document_content
+            )
+        )
+    ]
+
+    # Grader
+    llm: LLM = verifier_state["llm"]
+    tools: list[dict] = verifier_state["tools"]
+    response = list(
+        llm.stream(
+            prompt=msg,
+            tools=tools,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    formatted_response: BinaryDecision = json.loads(response[0].content)
+    verdict = formatted_response.decision
+
+    print(f"Verdict: {verdict}")
+
+    end_time = datetime.datetime.now()
+    if verdict == "yes":
+        return {
+            "sub_question_verified_retrieval_docs": [verifier_state["document"]],
+            "log_messages": f"{str(start_time)} - {str(end_time)}: sub - verifier: yes",
+        }
+    else:
+        return {
+            "sub_question_verified_retrieval_docs": [],
+            "log_messages": f"{str(start_time)} - {str(end_time)}: sub - verifier: no",
+        }
+
+
+def sub_generate(sub_qa_state: SubQAState) -> Dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        sub_qa_state (SubQAState): The current state
+
+    Returns:
+        dict: The updated state with the sub-question answer
+    """
+    print("---GENERATE---")
+    start_time = datetime.datetime.now()
+
+    question = sub_qa_state["sub_question"]
+    docs = sub_qa_state["sub_question_verified_retrieval_docs"]
+
+    print(f"Number of verified retrieval docs: {len(docs)}")
+
+    msg = [
+        HumanMessage(
+            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
+        )
+    ]
+
+    # Generation
+    llm: LLM = sub_qa_state["llm"]
+    tools: list[dict] = sub_qa_state["tools"]
+    response = list(
+        llm.stream(
+            prompt=msg,
+            tools=tools,
+            structured_response_format=None,
+        )
+    )
+
+    answer = response[0].content
+
+    end_time = datetime.datetime.now()
+    return {
+        "sub_question_answer": answer,
+        "log_messages": f"{str(start_time)} - {str(end_time)}: sub - generate",
+    }
+
+
+def sub_base_check(sub_qa_state: SubQAState) -> Dict[str, Any]:
+    """
+    Check whether the final output satisfies the original user question
+
+    Args:
+        sub_qa_state (SubQAState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---CHECK OUTPUT---")
+    start_time = datetime.datetime.now()
+
+    base_answer = sub_qa_state["core_answer_base_answer"]
+
+    question = sub_qa_state["original_question"]
+
+    BASE_CHECK_MESSAGE = [
+        HumanMessage(
+            content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
+        )
+    ]
+
+    llm: LLM = sub_qa_state["llm"]
+    tools: list[dict] = sub_qa_state["tools"]
+    response = list(
+        llm.stream(
+            prompt=BASE_CHECK_MESSAGE,
+            tools=tools,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    formatted_response: BinaryDecision = json.loads(response[0].content)
+    verdict = formatted_response.decision
+
+    print(f"Verdict: {verdict}")
+
+    end_time = datetime.datetime.now()
+    return {
+        "base_answer": base_answer,
+        "log_messages": f"{str(start_time)} - {str(end_time)}: sub - base_check",
+    }
+
+
+def sub_final_format(sub_qa_state: SubQAState) -> SubQAOutputState:
+    """
+    Create the final output for the QA subgraph
+    """
+
+    print("---BASE FINAL FORMAT---")
+
+    return {
+        "sub_qas": [
+            {
+                "sub_question": sub_qa_state["sub_question"],
+                "sub_answer": sub_qa_state["sub_question_answer"],
+                "sub_answer_check": sub_qa_state["sub_question_answer_check"],
+            }
+        ],
+        "log_messages": sub_qa_state["log_messages"],
+    }
+
+
+def sub_qa_check(sub_qa_state: SubQAState) -> Dict[str, str]:
+    """
+    Check if the sub-question answer is satisfactory.
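+
+    Returns e.g. {"sub_question_answer_check": "yes"} (value illustrative);
+    the stored value is the raw lowercased LLM verdict.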
+
+    Args:
+        sub_qa_state: The current SubQAState containing the sub-question and its answer
+
+    Returns:
+        Dict containing the check result and log message
+    """
+    start_time = datetime.datetime.now()
+
+    q = sub_qa_state["sub_question"]
+    a = sub_qa_state["sub_question_answer"]
+
+    BASE_CHECK_MESSAGE = [
+        HumanMessage(content=BASE_CHECK_PROMPT.format(question=q, base_answer=a))
+    ]
+
+    model: LLM = sub_qa_state["llm"]
+    response = model.invoke(BASE_CHECK_MESSAGE)
+
+    end_time = datetime.datetime.now()
+
+    return {
+        "sub_question_answer_check": response.content.lower(),
+        "log_messages": f"{str(start_time)} - {str(end_time)}: sub - qa_check",
+    }
diff --git a/backend/danswer/agent_search/subgraph/states.py b/backend/danswer/agent_search/subgraph/states.py
new file mode 100644
index 00000000000..d0f18fc21ea
--- /dev/null
+++ b/backend/danswer/agent_search/subgraph/states.py
@@ -0,0 +1,71 @@
+import operator
+from collections.abc import Sequence
+from typing import Annotated
+from typing import List
+from typing import TypedDict
+
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+from danswer.chat.models import DanswerContext
+from danswer.llm.interfaces import LLM
+
+
+class SubQuestionRetrieverState(TypedDict):
+    # The state for the parallel Retrievers. They each need to see only one query
+    sub_question_rewritten_query: str
+
+
+class SubQuestionVerifierState(TypedDict):
+    # The state for the parallel verification step. Each node execution needs to see only one question/doc pair
+    sub_question_document: DanswerContext
+    sub_question: str
+
+
+class SubQAState(TypedDict):
+    # The 'core SubQuestion' state.
+    original_question: str
+    sub_question_rewritten_queries: List[str]
+    sub_question: str
+    sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add]
+    sub_question_deduped_retrieval_docs: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]
+    sub_question_verified_retrieval_docs: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]
+    sub_question_reranked_retrieval_docs: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]
+    sub_question_top_chunks: Annotated[Sequence[DanswerContext], operator.add]
+    sub_question_answer: str
+    sub_question_answer_check: str
+    log_messages: Annotated[Sequence[BaseMessage], add_messages]
+    sub_qas: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]  # Answers sent back to core
+    llm: LLM
+    tools: list[dict]
+
+
+class SubQAOutputState(TypedDict):
+    # The 'SubQuestion' output state.
Removes all the intermediate states + sub_question_rewritten_queries: List[str] + sub_question: str + sub_qas: Annotated[ + Sequence[DanswerContext], operator.add + ] # Answers sent back to core + sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + sub_question_deduped_retrieval_docs: Annotated[ + Sequence[DanswerContext], operator.add + ] + sub_question_verified_retrieval_docs: Annotated[ + Sequence[DanswerContext], operator.add + ] + sub_question_reranked_retrieval_docs: Annotated[ + Sequence[DanswerContext], operator.add + ] + sub_question_top_chunks: Annotated[Sequence[DanswerContext], operator.add] + sub_question_answer: str + sub_question_answer_check: str + log_messages: Annotated[Sequence[BaseMessage], add_messages] diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 8a13bb8a74f..eb2cdd04b5a 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -26,9 +26,16 @@ huggingface-hub==0.20.1 jira==3.5.1 jsonref==1.1.0 trafilatura==1.12.2 -langchain==0.1.17 -langchain-core==0.1.50 -langchain-text-splitters==0.0.1 +langchain==0.3.7 +langchain-community==0.3.7 +langchain-core==0.3.20 +langchain-huggingface==0.1.2 +langchain-openai==0.2.9 +langchain-text-splitters==0.3.2 +langchainhub==0.1.21 +langgraph==0.2.53 +langgraph-checkpoint==2.0.5 +langgraph-sdk==0.1.36 litellm==1.53.1 lxml==5.3.0 lxml_html_clean==0.2.2 From 1be58e74b3470847c9ccd8268498acc94c166af3 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 6 Dec 2024 11:01:03 -0800 Subject: [PATCH 02/10] Finished primary graph --- .../agent_search/primary_graph/edges.py | 67 ++- .../primary_graph/graph_builder.py | 119 +++-- .../agent_search/primary_graph/nodes.py | 439 ++++++++++-------- .../agent_search/primary_graph/states.py | 35 +- backend/danswer/agent_search/run_graph.py | 3 +- .../shared_graph_utils/prompts.py | 438 ++++++++++++++--- .../agent_search/shared_graph_utils/utils.py | 76 +++ 7 files changed, 824 insertions(+), 353 deletions(-) diff --git a/backend/danswer/agent_search/primary_graph/edges.py b/backend/danswer/agent_search/primary_graph/edges.py index 577136b489b..238b92b52cf 100644 --- a/backend/danswer/agent_search/primary_graph/edges.py +++ b/backend/danswer/agent_search/primary_graph/edges.py @@ -1,34 +1,73 @@ +from collections.abc import Hashable +from typing import Union + +from langchain_core.messages import HumanMessage from langgraph.types import Send from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT -def continue_to_retrieval(state: QAState) -> list[Send]: +def continue_to_initial_sub_questions( + state: QAState, +) -> Union[Hashable, list[Hashable]]: # Routes re-written queries to the (parallel) retrieval steps # Notice the 'Send()' API that takes care of the parallelization return [ - Send("custom_retrieve", {"query": query}) - for query in state["rewritten_queries"] + Send( + "sub_answers_graph_initial", + { + "sub_question_str": initial_sub_question["sub_question_str"], + "sub_question_search_queries": initial_sub_question[ + "sub_question_search_queries" + ], + "sub_question_nr": initial_sub_question["sub_question_nr"], + "primary_llm": state["primary_llm"], + "fast_llm": state["fast_llm"], + "graph_start_time": state["graph_start_time"], + }, + ) + for initial_sub_question in state["initial_sub_questions"] ] -def continue_to_answer_sub_questions(state: QAState) -> list[Send]: +def 
continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: # Routes re-written queries to the (parallel) retrieval steps # Notice the 'Send()' API that takes care of the parallelization return [ - Send("sub_answers_graph", {"base_answer_sub_question": sub_question}) - for sub_question in state["sub_questions"] + Send( + "sub_answers_graph", + { + "sub_question": sub_question, + "sub_question_nr": sub_question_nr, + "primary_llm": state["primary_llm"], + "fast_llm": state["fast_llm"], + "graph_start_time": state["graph_start_time"], + }, + ) + for sub_question_nr, sub_question in state["sub_questions"].items() ] -def continue_to_verifier(state: QAState) -> list[Send]: - # Routes each de-douped retrieved doc to the verifier step - in parallel - # Notice the 'Send()' API that takes care of the parallelization +def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]: + print("---GO TO DEEP ANSWER OR END---") - return [ - Send( - "verifier", - {"document": doc, "original_question": state["original_question"]}, + base_answer = state["base_answer"] + + question = state["original_question"] + + BASE_CHECK_MESSAGE = [ + HumanMessage( + content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer) ) - for doc in state["deduped_retrieval_docs"] ] + + model = state["fast_llm"] + response = model.invoke(BASE_CHECK_MESSAGE) + + print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.content}") + + if response.content == "no": + return "decompose" + else: + return "end" diff --git a/backend/danswer/agent_search/primary_graph/graph_builder.py b/backend/danswer/agent_search/primary_graph/graph_builder.py index c429af17de1..90a0833d54e 100644 --- a/backend/danswer/agent_search/primary_graph/graph_builder.py +++ b/backend/danswer/agent_search/primary_graph/graph_builder.py @@ -3,114 +3,101 @@ from langgraph.graph import StateGraph from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions -from danswer.agent_search.primary_graph.edges import continue_to_retrieval -from danswer.agent_search.primary_graph.edges import continue_to_verifier -from danswer.agent_search.primary_graph.nodes import base_check +from danswer.agent_search.primary_graph.edges import continue_to_deep_answer +from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions +from danswer.agent_search.primary_graph.nodes import base_wait from danswer.agent_search.primary_graph.nodes import combine_retrieved_docs -from danswer.agent_search.primary_graph.nodes import consolidate_sub_qa from danswer.agent_search.primary_graph.nodes import custom_retrieve -from danswer.agent_search.primary_graph.nodes import decompose -from danswer.agent_search.primary_graph.nodes import deep_answer_generation +from danswer.agent_search.primary_graph.nodes import entity_term_extraction from danswer.agent_search.primary_graph.nodes import final_stuff -from danswer.agent_search.primary_graph.nodes import generate +from danswer.agent_search.primary_graph.nodes import generate_initial +from danswer.agent_search.primary_graph.nodes import main_decomp_base from danswer.agent_search.primary_graph.nodes import rewrite from danswer.agent_search.primary_graph.nodes import verifier from danswer.agent_search.primary_graph.states import QAState -from danswer.agent_search.subgraph.graph_builder import build_subgraph def build_core_graph() -> StateGraph: # Define the nodes we will cycle between - coreAnswerGraph = StateGraph(QAState) + core_answer_graph = 
StateGraph(QAState) + + ### Add Nodes ### # Re-writing the question - coreAnswerGraph.add_node(node="rewrite", action=rewrite) + core_answer_graph.add_node(node="rewrite", action=rewrite) # The retrieval step - coreAnswerGraph.add_node(node="custom_retrieve", action=custom_retrieve) + core_answer_graph.add_node(node="custom_retrieve", action=custom_retrieve) # Combine and dedupe retrieved docs. - coreAnswerGraph.add_node( + core_answer_graph.add_node( node="combine_retrieved_docs", action=combine_retrieved_docs ) - # Verifying that a retrieved doc is relevant - coreAnswerGraph.add_node(node="verifier", action=verifier) - - sub_answers_graph = build_subgraph() - # Answering a sub-question - coreAnswerGraph.add_node(node="sub_answers_graph", action=sub_answers_graph) + # Extract entities, terms and relationships + core_answer_graph.add_node( + node="entity_term_extraction", action=entity_term_extraction + ) - # A final clean-up step - coreAnswerGraph.add_node(node="final_stuff", action=final_stuff) + # Verifying that a retrieved doc is relevant + core_answer_graph.add_node(node="verifier", action=verifier) - # Decomposing the question into sub-questions - coreAnswerGraph.add_node(node="decompose", action=decompose) + # Initial question decomposition + core_answer_graph.add_node(node="main_decomp_base", action=main_decomp_base) # Checking whether the initial answer is in the ballpark - coreAnswerGraph.add_node(node="base_check", action=base_check) + core_answer_graph.add_node(node="base_wait", action=base_wait) + + # A final clean-up step + core_answer_graph.add_node(node="final_stuff", action=final_stuff) # Generating a response after we know the documents are relevant - coreAnswerGraph.add_node(node="generate", action=generate) + core_answer_graph.add_node(node="generate_initial", action=generate_initial) - # Consolidating the sub-questions and answers - coreAnswerGraph.add_node(node="consolidate_sub_qa", action=consolidate_sub_qa) + ### Add Edges ### - # Generating a deep answer - coreAnswerGraph.add_node( - node="deep_answer_generation", action=deep_answer_generation + # start the initial sub-question decomposition + core_answer_graph.add_edge(start_key=START, end_key="main_decomp_base") + core_answer_graph.add_conditional_edges( + source="main_decomp_base", + path=continue_to_initial_sub_questions, + path_map={"sub_answers_graph_initial": "sub_answers_graph_initial"}, ) - ### Edges ### - - # start by rewriting the prompt - coreAnswerGraph.add_edge(start_key=START, end_key="rewrite") - - # Kick off another flow to decompose the question into sub-questions - coreAnswerGraph.add_edge(start_key=START, end_key="decompose") - - coreAnswerGraph.add_conditional_edges( - source="rewrite", - path=continue_to_retrieval, - path_map={"custom_retrieve": "custom_retrieve"}, + # use the retrieved information to generate the answer + core_answer_graph.add_edge( + start_key=["verifier", "sub_answers_graph_initial"], end_key="generate_initial" ) + core_answer_graph.add_edge(start_key="generate_initial", end_key="base_wait") - # check whether answer addresses the question - coreAnswerGraph.add_edge( - start_key="custom_retrieve", end_key="combine_retrieved_docs" + core_answer_graph.add_conditional_edges( + source="base_wait", + path=continue_to_deep_answer, + path_map={"decompose": "entity_term_extraction", "end": "final_stuff"}, ) - coreAnswerGraph.add_conditional_edges( - source="combine_retrieved_docs", - path=continue_to_verifier, - path_map={"verifier": "verifier"}, - ) + 
core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose") - coreAnswerGraph.add_conditional_edges( - source="decompose", + core_answer_graph.add_edge(start_key="decompose", end_key="sub_qa_manager") + core_answer_graph.add_conditional_edges( + source="sub_qa_manager", path=continue_to_answer_sub_questions, path_map={"sub_answers_graph": "sub_answers_graph"}, ) - # use the retrieved information to generate the answer - coreAnswerGraph.add_edge(start_key="verifier", end_key="generate") - - # check whether answer addresses the question - coreAnswerGraph.add_edge(start_key="generate", end_key="base_check") - - coreAnswerGraph.add_edge( - start_key="sub_answers_graph", end_key="consolidate_sub_qa" + core_answer_graph.add_edge( + start_key="sub_answers_graph", end_key="sub_qa_level_aggregator" ) - coreAnswerGraph.add_edge( - start_key="consolidate_sub_qa", end_key="deep_answer_generation" + core_answer_graph.add_edge( + start_key="sub_qa_level_aggregator", end_key="deep_answer_generation" ) - coreAnswerGraph.add_edge( - start_key=["base_check", "deep_answer_generation"], end_key="final_stuff" + core_answer_graph.add_edge( + start_key="deep_answer_generation", end_key="final_stuff" ) - coreAnswerGraph.add_edge(start_key="final_stuff", end_key=END) - coreAnswerGraph.compile() + core_answer_graph.add_edge(start_key="final_stuff", end_key=END) + core_answer_graph.compile() - return coreAnswerGraph + return core_answer_graph diff --git a/backend/danswer/agent_search/primary_graph/nodes.py b/backend/danswer/agent_search/primary_graph/nodes.py index 19e8b38e5f7..bb5621deabe 100644 --- a/backend/danswer/agent_search/primary_graph/nodes.py +++ b/backend/danswer/agent_search/primary_graph/nodes.py @@ -1,5 +1,7 @@ -import datetime import json +import re +from collections.abc import Sequence +from datetime import datetime from typing import Any from typing import Dict from typing import Literal @@ -10,18 +12,21 @@ from danswer.agent_search.primary_graph.states import QAState from danswer.agent_search.primary_graph.states import RetrieverState from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT -from danswer.agent_search.shared_graph_utils.prompts import DECOMPOSE_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace +from danswer.agent_search.shared_graph_utils.utils import generate_log_message from danswer.chat.models import DanswerContext from danswer.llm.interfaces import LLM +# Maybe try Partial[QAState] +# from typing import Partial + # Pydantic models for structured outputs class RewrittenQueries(BaseModel): @@ -37,7 +42,7 @@ class SubQuestions(BaseModel): # 
Transform the initial question into more suitable search queries. -def rewrite(qa_state: QAState) -> Dict[str, Any]: +def rewrite(state: QAState) -> Dict[str, Any]: """ Transform the initial question into more suitable search queries. @@ -47,12 +52,13 @@ def rewrite(qa_state: QAState) -> Dict[str, Any]: Returns: dict: The updated state with re-phrased question """ + print("---STARTING GRAPH---") + graph_start_time = datetime.now() print("---TRANSFORM QUERY---") + node_start_time = datetime.now() - start_time = datetime.datetime.now() - - question = qa_state["original_question"] + question = state["original_question"] msg = [ HumanMessage( @@ -61,26 +67,27 @@ def rewrite(qa_state: QAState) -> Dict[str, Any]: ] # Get the rewritten queries in a defined format - llm: LLM = qa_state["llm"] - tools: list[dict] = qa_state["tools"] - response = list( - llm.stream( + fast_llm: LLM = state["fast_llm"] + llm_response = list( + fast_llm.stream( prompt=msg, - tools=tools, structured_response_format=RewrittenQueries.model_json_schema(), ) ) - formatted_response: RewrittenQueries = json.loads(response[0].content) + formatted_response: RewrittenQueries = json.loads(llm_response[0].content) - end_time = datetime.datetime.now() return { "rewritten_queries": formatted_response.rewritten_queries, - "log_messages": f"{str(start_time)} - {str(end_time)}: core - rewrite", + "log_messages": generate_log_message( + message="core - rewrite", + node_start_time=node_start_time, + graph_start_time=graph_start_time, + ), } -def custom_retrieve(retriever_state: RetrieverState) -> Dict[str, Any]: +def custom_retrieve(state: RetrieverState) -> Dict[str, Any]: """ Retrieve documents @@ -92,41 +99,49 @@ def custom_retrieve(retriever_state: RetrieverState) -> Dict[str, Any]: """ print("---RETRIEVE---") - start_time = datetime.datetime.now() + node_start_time = datetime.now() - retriever_state["rewritten_query"] + # query = state["rewritten_query"] # Retrieval # TODO: add the actual retrieval, probably from search_tool.run() documents: list[DanswerContext] = [] - end_time = datetime.datetime.now() return { "base_retrieval_docs": documents, - "log_messages": f"{str(start_time)} - {str(end_time)}: core - custom_retrieve", + "log_messages": generate_log_message( + message="core - custom_retrieve", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } -def combine_retrieved_docs(qa_state: QAState) -> Dict[str, Any]: +def combine_retrieved_docs(state: QAState) -> Dict[str, Any]: """ Dedupe the retrieved docs. 
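+    Dedupe is keyed on document_id, so two chunks sharing e.g.
+    document_id="doc-123" (illustrative id) collapse into one entry.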
""" - start_time = datetime.datetime.now() + node_start_time = datetime.now() - base_retrieval_docs = qa_state["base_retrieval_docs"] + base_retrieval_docs: Sequence[DanswerContext] = state["base_retrieval_docs"] print(f"Number of docs from steps: {len(base_retrieval_docs)}") - dedupe_docs = [] + dedupe_docs: list[DanswerContext] = [] for base_retrieval_doc in base_retrieval_docs: - if base_retrieval_doc not in dedupe_docs: + if not any( + base_retrieval_doc.document_id == doc.document_id for doc in dedupe_docs + ): dedupe_docs.append(base_retrieval_doc) print(f"Number of deduped docs: {len(dedupe_docs)}") - end_time = datetime.datetime.now() return { "deduped_retrieval_docs": dedupe_docs, - "log_messages": f"{str(start_time)} - {str(end_time)}: core - combine_retrieved_docs (dedupe)", + "log_messages": generate_log_message( + message="core - combine_retrieved_docs (dedupe)", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } @@ -142,7 +157,7 @@ def verifier(state: VerifierState) -> Dict[str, Any]: """ print("---VERIFY QUTPUT---") - start_time = datetime.datetime.now() + node_start_time = datetime.now() question = state["original_question"] document_content = state["document"].content @@ -156,56 +171,45 @@ def verifier(state: VerifierState) -> Dict[str, Any]: ] # Grader - llm: LLM = state["llm"] - tools: list[dict] = state["tools"] + llm: LLM = state["fast_llm"] response = list( llm.stream( prompt=msg, - tools=tools, structured_response_format=BinaryDecision.model_json_schema(), ) ) formatted_response: BinaryDecision = response[0].content - end_time = datetime.datetime.now() - if formatted_response.decision == "yes": - end_time = datetime.datetime.now() - return { - "deduped_retrieval_docs": [state["document"]], - "log_messages": f"{str(start_time)} - {str(end_time)}: core - verifier: yes", - } - else: - end_time = datetime.datetime.now() - return { - "deduped_retrieval_docs": [], - "log_messages": f"{str(start_time)} - {str(end_time)}: core - verifier: no", - } + return { + "deduped_retrieval_docs": [state["document"]] + if formatted_response.decision == "yes" + else [], + "log_messages": generate_log_message( + message=f"core - verifier: {formatted_response.decision}", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } -def generate(qa_state: QAState) -> Dict[str, Any]: +def generate(state: QAState) -> Dict[str, Any]: """ Generate answer Args: - qa_state (messages): The current state + state (messages): The current state Returns: dict: The updated state with re-phrased question """ print("---GENERATE---") - start_time = datetime.datetime.now() + node_start_time = datetime.now() - question = qa_state["original_question"] - docs = qa_state["deduped_retrieval_docs"] - - print(f"Number of verified retrieval docs: {docs}") - - # LLM - llm: LLM = qa_state["llm"] + question = state["original_question"] + docs = state["deduped_retrieval_docs"] - # Chain - # rag_chain = BASE_RAG_PROMPT | llm | StrOutputParser() + print(f"Number of verified retrieval docs: {len(docs)}") msg = [ HumanMessage( @@ -214,12 +218,10 @@ def generate(qa_state: QAState) -> Dict[str, Any]: ] # Grader - llm: LLM = qa_state["llm"] - tools: list[dict] = qa_state["tools"] + llm: LLM = state["fast_llm"] response = list( llm.stream( prompt=msg, - tools=tools, structured_response_format=None, ) ) @@ -228,94 +230,78 @@ def generate(qa_state: QAState) -> Dict[str, Any]: # response = rag_chain.invoke({"context": docs, # "question": question}) - end_time = 
datetime.datetime.now() return { "base_answer": response[0].content, - "log_messages": f"{str(start_time)} - {str(end_time)}: core - generate", - } - - -def base_check(qa_state: QAState) -> Dict[str, Any]: - """ - Check whether the final output satisfies the original user question - - Args: - qa_state (messages): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---CHECK QUTPUT---") - start_time = datetime.datetime.now() - - # time.sleep(5) - - initial_base_answer = qa_state["initial_base_answer"] - - question = qa_state["original_question"] - - BASE_CHECK_MESSAGE = [ - HumanMessage( - content=BASE_CHECK_PROMPT.format( - question=question, base_answer=initial_base_answer - ) - ) - ] - - llm: LLM = qa_state["llm"] - tools: list[dict] = qa_state["tools"] - response = list( - llm.stream( - prompt=BASE_CHECK_MESSAGE, - tools=tools, - structured_response_format=None, - ) - ) - - print(f"Verdict: {response[0].content}") - - end_time = datetime.datetime.now() - return { - "base_answer": initial_base_answer, - "log_messages": f"{str(start_time)} - {str(end_time)}: core - base_check", + "log_messages": generate_log_message( + message="core - generate", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } -def final_stuff(qa_state: QAState) -> Dict[str, Any]: +def final_stuff(state: QAState) -> Dict[str, Any]: """ Invokes the agent model to generate a response based on the current state. Given the question, it will decide to retrieve using the retriever tool, or simply end. Args: - qa_state (messages): The current state + state (messages): The current state Returns: dict: The updated state with the agent response appended to messages """ print("---FINAL---") - start_time = datetime.datetime.now() + node_start_time = datetime.now() - messages = qa_state["log_messages"] + messages = state["log_messages"] time_ordered_messages = [x.content for x in messages] time_ordered_messages.sort() print("Message Log:") print("\n".join(time_ordered_messages)) - end_time = datetime.datetime.now() - print(f"{str(start_time)} - {str(end_time)}: all - final_stuff") + initial_sub_qas = state["initial_sub_qas"] + initial_sub_qa_list = [] + for initial_sub_qa in initial_sub_qas: + if initial_sub_qa["sub_answer_check"] == "yes": + initial_sub_qa_list.append( + f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----' + ) + + initial_sub_qa_context = "\n".join(initial_sub_qa_list) + + log_message = generate_log_message( + message="all - final_stuff", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ) + + print(log_message) + print("--------------------------------") + + base_answer = state["base_answer"] + print(f"Final Base Answer:\n{base_answer}") + print("--------------------------------") + print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}") print("--------------------------------") - base_answer = qa_state["base_answer"] - deep_answer = qa_state["deep_answer"] - sub_qas = qa_state["checked_sub_qas"] + if not state.get("deep_answer"): + print("No Deep Answer was required") + return { + "log_messages": log_message, + } + + deep_answer = state["deep_answer"] + sub_qas = state["sub_qas"] sub_qa_list = [] for sub_qa in sub_qas: - sub_qa_list.append( - f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----' - ) + if sub_qa["sub_answer_check"] == "yes": + sub_qa_list.append( + f' Question:\n 
{sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----' + ) + sub_qa_context = "\n".join(sub_qa_list) print(f"Final Base Answer:\n{base_answer}") @@ -326,137 +312,194 @@ def final_stuff(qa_state: QAState) -> Dict[str, Any]: print(sub_qa_context) return { - "log_messages": f"{str(start_time)} - {str(end_time)}: all - final_stuff", + "log_messages": generate_log_message( + message="all - final_stuff", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } -# nodes - - -def decompose(qa_state: QAState) -> Dict[str, Any]: +def base_wait(state: QAState) -> Dict[str, Any]: """ - Decompose a complex question into simpler sub-questions. + Ensures that all required steps are completed before proceeding to the next step Args: - qa_state: The current QA state containing the original question and LLM + state (messages): The current state Returns: - Dict containing sub_questions and log messages + dict: {} (no operation, just logging) """ - start_time = datetime.datetime.now() + print("---Base Wait ---") + node_start_time = datetime.now() + return { + "log_messages": generate_log_message( + message="core - base_wait", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } + + +def entity_term_extraction(state: QAState) -> Dict[str, Any]: + """ """ - question = qa_state["original_question"] + node_start_time = datetime.now() + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + doc_context = format_docs(docs) msg = [ HumanMessage( - content=DECOMPOSE_PROMPT.format(question=question), + content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), ) ] # Grader - llm: LLM = qa_state["llm"] - tools: list[dict] = qa_state["tools"] - response = list( - llm.stream( - prompt=msg, - tools=tools, - structured_response_format=SubQuestions.model_json_schema(), - ) - ) + model = state["fast_llm"] + response = model.invoke(msg) - formatted_response: SubQuestions = response[0].content + cleaned_response = re.sub(r"```json\n|\n```", "", response.content) + parsed_response = json.loads(cleaned_response) - end_time = datetime.datetime.now() return { - "sub_questions": formatted_response.sub_questions, - "log_messages": f"{str(start_time)} - {str(end_time)}: deep - decompose", + "retrieved_entities_relationships": parsed_response, + "log_messages": generate_log_message( + message="deep - entity term extraction", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } -# aggregate sub questions and answers -def consolidate_sub_qa(qa_state: QAState) -> Dict[str, Any]: +def generate_initial(state: QAState) -> Dict[str, Any]: """ - Consolidate sub-questions and their answers. 
+ Generate answer Args: - qa_state: The current QA state containing sub QAs + state (messages): The current state Returns: - Dict containing dynamic context, checked sub QAs and log messages + dict: The updated state with re-phrased question """ - sub_qas = qa_state["sub_qas"] + print("---GENERATE INITIAL---") + node_start_time = datetime.now() - start_time = datetime.datetime.now() + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + print(f"Number of verified retrieval docs - base: {len(docs)}") - dynamic_context_list = [ - "Below you will find useful information to answer the original question:" - ] - checked_sub_qas = [] + sub_question_answers = state["initial_sub_qas"] - for sub_qa in sub_qas: - question = sub_qa["sub_question"] - answer = sub_qa["sub_answer"] - verified = sub_qa["sub_answer_check"] + sub_question_answers_list = [] - if verified == "yes": - dynamic_context_list.append( - f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n" + _SUB_QUESTION_ANSWER_TEMPLATE = """ + Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n + """ + for sub_question_answer_dict in sub_question_answers: + if ( + sub_question_answer_dict["sub_answer_check"] == "yes" + and len(sub_question_answer_dict["sub_answer"]) > 0 + and sub_question_answer_dict["sub_answer"] != "I don't know" + ): + sub_question_answers_list.append( + _SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=sub_question_answer_dict["sub_question"], + sub_answer=sub_question_answer_dict["sub_answer"], + ) ) - checked_sub_qas.append({"sub_question": question, "sub_answer": answer}) - dynamic_context = "\n".join(dynamic_context_list) - end_time = datetime.datetime.now() + sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list) + + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT.format( + question=question, + context=format_docs(docs), + answered_sub_questions=sub_question_answer_str, + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + # Run + # response = rag_chain.invoke({"context": docs, + # "question": question}) + return { - "dynamic_context": dynamic_context, - "checked_sub_qas": checked_sub_qas, - "log_messages": f"{str(start_time)} - {str(end_time)}: deep - consolidate_sub_qa", + "base_answer": response.content, + "log_messages": generate_log_message( + message="core - generate initial", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } -# aggregate sub questions and answers -def deep_answer_generation(qa_state: QAState) -> Dict[str, Any]: +def main_decomp_base(state: QAState) -> Dict[str, Any]: """ - Generate answer + Perform an initial question decomposition, incl. 
one search term Args: - qa_state (messages): The current state + state (messages): The current state Returns: - dict: The updated state with re-phrased question + dict: The updated state with initial decomposition """ - print("---GENERATE---") - start_time = datetime.datetime.now() - question = qa_state["original_question"] - docs = qa_state["deduped_retrieval_docs"] + print("---INITIAL DECOMP---") + node_start_time = datetime.now() - deep_answer_context = qa_state["dynamic_context"] - - print(f"Number of verified retrieval docs: {docs}") + question = state["original_question"] - combined_context = normalize_whitespace( - COMBINED_CONTEXT.format( - deep_answer_context=deep_answer_context, formated_docs=format_docs(docs) + msg = [ + HumanMessage( + content=INITIAL_DECOMPOSITION_PROMPT.format(question=question), ) - ) - + ] + """ msg = [ HumanMessage( - content=MODIFIED_RAG_PROMPT.format( - question=question, combined_context=combined_context - ) + content=INITIAL_DECOMPOSITION_PROMPT_BASIC.format(question=question), ) ] + """ - # Grader - # LLM - model: LLM = qa_state["llm"] + # Get the rewritten queries in a defined format + model = state["fast_llm"] response = model.invoke(msg) - end_time = datetime.datetime.now() + content = response.content + list_of_subquestions = clean_and_parse_list_string(content) + # response = model.invoke(msg) + + decomp_list = [] + + for sub_question_nr, sub_question in enumerate(list_of_subquestions): + sub_question_str = sub_question["sub_question"].strip() + # temporarily + sub_question_search_queries = [sub_question["search_term"]] + + decomp_list.append( + { + "sub_question_str": sub_question_str, + "sub_question_search_queries": sub_question_search_queries, + "sub_question_nr": sub_question_nr, + } + ) + return { - "final_deep_answer": response.content, - "log_messages": f"{str(start_time)} - {str(end_time)}: deep - deep_answer_generation", + "initial_sub_questions": decomp_list, + "start_time_temp": node_start_time, + "log_messages": generate_log_message( + message="core - initial decomp", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), } - # return {"log_messages": [response]} diff --git a/backend/danswer/agent_search/primary_graph/states.py b/backend/danswer/agent_search/primary_graph/states.py index 4729d5f8f63..90e44d636b7 100644 --- a/backend/danswer/agent_search/primary_graph/states.py +++ b/backend/danswer/agent_search/primary_graph/states.py @@ -1,7 +1,7 @@ import operator from collections.abc import Sequence +from datetime import datetime from typing import Annotated -from typing import List from typing import TypedDict from langchain_core.messages import BaseMessage @@ -14,37 +14,48 @@ class QAState(TypedDict): # The 'main' state of the answer graph original_question: str + graph_start_time: datetime + sub_query_start_time: datetime # start time for parallel initial sub-questionn thread log_messages: Annotated[Sequence[BaseMessage], add_messages] - rewritten_queries: List[str] - sub_questions: List[str] + rewritten_queries: list[str] + sub_questions: list[dict] + initial_sub_questions: list[dict] + ranked_subquestion_ids: list[int] + decomposed_sub_questions_dict: dict + rejected_sub_questions: Annotated[list[str], operator.add] + rejected_sub_questions_handled: bool sub_qas: Annotated[Sequence[dict], operator.add] + initial_sub_qas: Annotated[Sequence[dict], operator.add] checked_sub_qas: Annotated[Sequence[dict], operator.add] base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] 
deduped_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] reranked_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] retrieved_entities_relationships: dict - questions_context: List[dict] + questions_context: list[dict] qa_level: int - top_chunks: List[DanswerContext] + top_chunks: list[DanswerContext] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] + num_new_question_iterations: int + core_answer_dynamic_context: str dynamic_context: str initial_base_answer: str base_answer: str deep_answer: str - llm: LLM - tools: list[dict] + primary_llm: LLM + fast_llm: LLM class QAOuputState(TypedDict): # The 'main' output state of the answer graph. Removes all the intermediate states original_question: str log_messages: Annotated[Sequence[BaseMessage], add_messages] - sub_questions: List[str] + sub_questions: list[dict] sub_qas: Annotated[Sequence[dict], operator.add] + initial_sub_qas: Annotated[Sequence[dict], operator.add] checked_sub_qas: Annotated[Sequence[dict], operator.add] reranked_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] retrieved_entities_relationships: dict - top_chunks: List[DanswerContext] + top_chunks: list[DanswerContext] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] base_answer: str deep_answer: str @@ -53,9 +64,15 @@ class QAOuputState(TypedDict): class RetrieverState(TypedDict): # The state for the parallel Retrievers. They each need to see only one query rewritten_query: str + primary_llm: LLM + fast_llm: LLM + graph_start_time: datetime class VerifierState(TypedDict): # The state for the parallel verification step. Each node execution need to see only one question/doc pair document: DanswerContext original_question: str + primary_llm: LLM + fast_llm: LLM + graph_start_time: datetime diff --git a/backend/danswer/agent_search/run_graph.py b/backend/danswer/agent_search/run_graph.py index bbee5ce0270..02c14a64438 100644 --- a/backend/danswer/agent_search/run_graph.py +++ b/backend/danswer/agent_search/run_graph.py @@ -17,5 +17,6 @@ def run_graph( "tools": tools, "llm": llm, } - output = graph.invoke(input=inputs) + compiled_graph = graph.compile() + output = compiled_graph.invoke(input=inputs) yield from output diff --git a/backend/danswer/agent_search/shared_graph_utils/prompts.py b/backend/danswer/agent_search/shared_graph_utils/prompts.py index 3cf27167ae3..9bc1789cd4b 100644 --- a/backend/danswer/agent_search/shared_graph_utils/prompts.py +++ b/backend/danswer/agent_search/shared_graph_utils/prompts.py @@ -1,49 +1,126 @@ -REWRITE_PROMPT_MULTI = """\n - Please convert an initial user question into a 2-3 more appropriate - search queries for retrievel from a document store.\n +REWRITE_PROMPT_MULTI = """ \n + Please convert an initial user question into a 2-3 more appropriate search queries for retrievel from a + document store. \n Here is the initial question: \n ------- \n {question} \n ------- \n - Formulate the query: -""" + Formulate the query: """ + +INITIAL_DECOMPOSITION_PROMPT = """ \n + Please decompose an initial user question into not more than 4 appropriate sub-questions that help to + answer the original question. 
The purpose for this decomposition is to isolate individual entities
+    (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+    for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+    sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+    for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
+
+    For each sub-question, please also create one search term that can be used to retrieve relevant
+    documents from a document store.
+
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Please formulate your answer as a list of json objects with the following format:
+
+    [{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
+
+    Answer:
+    """
+
+INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
+    Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
+    answer the original question. The purpose for this decomposition is to isolate individual entities
+    (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+    for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+    sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+    for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
 
-REWRITE_PROMPT_SINGLE = """\n
-    Please convert an initial user question into a more appropriate search
-    query for retrievel from a document store.\n
     Here is the initial question:
     \n ------- \n
     {question}
     \n ------- \n
-    Formulate the query:
-"""
+    Please formulate your answer as a list of sub-questions:
+
+    Answer:
+    """
 
+REWRITE_PROMPT_SINGLE = """ \n
+    Please convert an initial user question into a more appropriate search query for retrieval from a
+    document store. \n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Formulate the query: """
+
+BASE_RAG_PROMPT = """ \n
+    You are an assistant for question-answering tasks. Use the context provided below - and only the
+    provided context - to answer the question. If you don't know the answer or if the provided context is
+    empty, just say "I don't know". Do not use your internal knowledge!
+
+    Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
+    question based on the context, say "I don't know". It is a matter of life and death that you do NOT
+    use your internal knowledge, just the provided information!
 
-BASE_RAG_PROMPT = """\n
-    You are an assistant for question-answering tasks.
-    Use the following pieces of retrieved context to answer the question.
-    If you don't know the answer, just say that you don't know.
     Use three sentences maximum and keep the answer concise.
-    \nQuestion: {question}
-    \nContext: {context}
-    \nAnswer:
-"""
+    \nQuestion:\n {question} \nContext:\n {context} \n\n
+    \n\n
+    Answer:"""
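Reviewer's note, not part of the diff itself: INITIAL_DECOMPOSITION_PROMPT asks for a JSON list, and main_decomp_base (earlier in this patch) parses the reply with clean_and_parse_list_string from shared_graph_utils/utils.py. A minimal sketch of that round trip, assuming a well-formed reply; the sample reply and its contents are fabricated:

    import ast
    import re

    def clean_and_parse_list_string(json_string: str) -> list[dict]:
        # Same cleanup as the helper added to shared_graph_utils/utils.py:
        # strip markdown fences, flatten whitespace, then literal-eval.
        cleaned = re.sub(r"```json\n|\n```", "", json_string)
        cleaned = cleaned.replace("\\n", " ").replace("\n", " ")
        cleaned = " ".join(cleaned.split())
        return ast.literal_eval(cleaned)

    # Hypothetical model reply in the format requested above.
    sample_reply = (
        '```json\n'
        '[{"sub_question": "What are sales for company A?", "search_term": "company A sales"},\n'
        ' {"sub_question": "What are sales for company B?", "search_term": "company B sales"}]\n'
        '```'
    )

    for nr, sub_question in enumerate(clean_and_parse_list_string(sample_reply)):
        print(nr, sub_question["sub_question"], "->", sub_question["search_term"])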
+
+INITIAL_RAG_PROMPT = """ \n
+    You are an assistant for question-answering tasks. Use the information provided below - and only the
+    provided information - to answer the provided question.
+
+    The information provided below consists of:
+    1) a number of answered sub-questions - these are very important(!) and definitely should be
+    considered to answer the question.
+    2) a number of documents that were also deemed relevant for the question.
+
+    If you don't know the answer or if the provided information is empty or insufficient, just say
+    "I don't know". Do not use your internal knowledge!
+
+    Again, only use the provided information and do not use your internal knowledge! It is a matter of life
+    and death that you do NOT use your internal knowledge, just the provided information!
+
+    Try to keep your answer concise.
+
+    And here is the question and the provided information:
+    \n
+    \nQuestion:\n {question}
+
+    \nAnswered Sub-questions:\n {answered_sub_questions}
+
+    \nContext:\n {context} \n\n
+    \n\n
+
+    Answer:"""
 
-MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks.
-    Use the following pieces of retrieved context to answer the question.
-    If you don't know the answer, just say that you don't know.
+MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
+    - and only this context - to answer the question. If you don't know the answer, just say "I don't know".
     Use three sentences maximum and keep the answer concise.
-    Pay also particular attention to the sub-questions and their answers,
-    at least it may enrich the answer.
+    Please also pay particular attention to the sub-questions and their answers, as they may enrich the answer.
+    Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
+    question based on the context, say "I don't know". It is a matter of life and death that you do NOT
+    use your internal knowledge, just the provided information!
+
     \nQuestion: {question}
-    \nContext: {combined_context}
-    \nAnswer:
-"""
+    \nContext: {combined_context} \n
+
+    Answer:"""
 
-BASE_CHECK_PROMPT = """\n
-    Please check whether the suggested answer seems to address the original question. Please only answer with 'yes' or 'no'\n
+BASE_CHECK_PROMPT = """ \n
+    Please check whether 1) the suggested answer seems to fully address the original question AND 2) the
+    original question requests a simple, factual answer, and there are no ambiguities, judgements,
+    aggregations, or any other complications that may require extra context. (I.e., if the question is
+    somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
+
+    Please only answer with 'yes' or 'no' \n
     Here is the initial question:
     \n ------- \n
     {question}
@@ -52,13 +129,25 @@
     \n ------- \n
     {base_answer}
     \n ------- \n
-    Please answer with yes or no:
-"""
+    Please answer with yes or no:"""
 
+SUB_CHECK_PROMPT = """ \n
+    Please check whether the suggested answer seems to address the original question.
+
+    Please only answer with 'yes' or 'no' \n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Here is the proposed answer:
+    \n ------- \n
+    {base_answer}
+    \n ------- \n
+    Please answer with yes or no:"""
+
+VERIFIER_PROMPT = """ \n
+    Please check whether the document seems to be relevant for the answer of the original question. Please
+    only answer with 'yes' or 'no' \n
-VERIFIER_PROMPT = """\n
-    Please check whether the document seems to be relevant for the
-    answer of the original question. Please only answer with 'yes' or 'no'\n
     Here is the initial question:
     \n ------- \n
     {question}
@@ -67,56 +156,275 @@
     \n ------- \n
     {document_content}
     \n ------- \n
-    Please answer with yes or no:
-"""
+    Please answer with yes or no:"""
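Reviewer's note, not part of the diff itself: generate_initial (earlier in this patch) fills INITIAL_RAG_PROMPT after filtering the initial sub-question answers. A minimal sketch of that assembly, with fabricated sub-QA entries and an abbreviated stand-in for the template:

    # Abbreviated stand-in for INITIAL_RAG_PROMPT above (same placeholders).
    PROMPT_STUB = (
        "Question:\n {question}\n\n"
        "Answered Sub-questions:\n {answered_sub_questions}\n\n"
        "Context:\n {context}"
    )

    sub_qas = [
        {"sub_question": "What are sales for company A?",
         "sub_answer": "About $10M in FY24.", "sub_answer_check": "yes"},
        {"sub_question": "What are sales for company B?",
         "sub_answer": "I don't know", "sub_answer_check": "yes"},
    ]

    # Mirror the filtering in generate_initial: keep only verified, non-empty,
    # non-"I don't know" sub-answers (the second entry is dropped).
    kept = [
        f"Sub-Question:\n - {qa['sub_question']}\n --\nAnswer:\n - {qa['sub_answer']}"
        for qa in sub_qas
        if qa["sub_answer_check"] == "yes"
        and len(qa["sub_answer"]) > 0
        and qa["sub_answer"] != "I don't know"
    ]

    print(PROMPT_STUB.format(
        question="Compare sales of company A and company B",
        answered_sub_questions="\n\n------\n\n".join(kept),
        context="(formatted docs would go here)",
    ))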
+
+ENTITY_TERM_PROMPT = """ \n
+    Based on the original question and the context retrieved from a dataset, please generate a list of
+    entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
+    (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
+
+    \n\n
+    Here is the original question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    And here is the context retrieved:
+    \n ------- \n
+    {context}
+    \n ------- \n
+
+    Please format your answer as a json object in the following format:
+
+    {{"retrieved_entities_relationships": {{
+        "entities": [{{
+            "entity_name": <entity name>,
+            "entity_type": <entity type>
+        }}],
+        "relationships": [{{
+            "name": <relationship name>,
+            "type": <relationship type>,
+            "entities": [<entity name 1>, <entity name 2>]
+        }}],
+        "terms": [{{
+            "term_name": <term name>,
+            "term_type": <term type>,
+            "similar_to": <terms this term is similar to>
+        }}]
+    }}
+    }}
+    """
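Reviewer's note, not part of the diff itself: entity_term_extraction (earlier in this patch) stores the parsed reply under retrieved_entities_relationships, and format_entity_term_extraction in shared_graph_utils/utils.py (later in this patch) flattens the inner dict. A minimal sketch with a fabricated reply:

    # Fabricated parsed reply in the shape ENTITY_TERM_PROMPT requests.
    parsed = {
        "retrieved_entities_relationships": {
            "entities": [{"entity_name": "Company A", "entity_type": "company"}],
            "relationships": [
                {"name": "competes_with", "type": "competition",
                 "entities": ["Company A", "Company B"]}
            ],
            "terms": [{"term_name": "sales", "term_type": "metric",
                       "similar_to": "revenue"}],
        }
    }

    inner = parsed["retrieved_entities_relationships"]
    # The same per-section flattening format_entity_term_extraction performs.
    lines = ["Entities:"]
    lines += [f"{e['entity_name']} ({e['entity_type']})" for e in inner["entities"]]
    lines += ["Relationships:"]
    lines += [f"{r['name']} ({r['type']}): {r['entities']}" for r in inner["relationships"]]
    lines += ["Terms:"]
    lines += [f"{t['term_name']} ({t['term_type']}): similar to {t['similar_to']}" for t in inner["terms"]]
    print("\n".join(lines))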
+
+ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
+    An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
+    good enough. Also, some sub-questions had been answered and this information has been used to provide
+    the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they
+    were not directly answerable. Also, some entities, relationships and terms are given to you so that
+    you have an idea of what the available data looks like.
+
+    Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
+    considering:
+
+    1) The initial question
+    2) The initial answer that was found to be unsatisfactory
+    3) The sub-questions that were answered
+    4) The sub-questions that were suggested but not answered
+    5) The entities, relationships and terms that were extracted from the context
+
+    The individual questions should be answerable by a good RAG system.
+    So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
+    question for different entities that may be involved in the original question, but in a way that does
+    not duplicate questions that were already tried.
+
+    Additional Guidelines:
+    - The sub-questions should be specific to the question and provide richer context for the question,
+    resolve ambiguities, or address shortcomings of the initial answer
+    - Each sub-question - when answered - should be relevant for the answer to the original question
+    - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
+    other complications that may require extra context.
+    - The sub-questions MUST have the full context of the original question so that it can be executed by
+    a RAG system independently without the original question available
+    (Example:
+    - initial question: "What is the capital of France?"
+    - bad sub-question: "What is the name of the river there?"
+    - good sub-question: "What is the name of the river that flows through Paris?")
+    - For each sub-question, please provide a short explanation for why it is a good sub-question. So
+    generate a list of dictionaries with the following format:
+    [{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <search term>}}, ...]
+
+    \n\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Here is the initial sub-optimal answer:
+    \n ------- \n
+    {base_answer}
+    \n ------- \n
+
+    Here are the sub-questions that were answered:
+    \n ------- \n
+    {answered_sub_questions}
+    \n ------- \n
+
+    Here are the sub-questions that were suggested but not answered:
+    \n ------- \n
+    {failed_sub_questions}
+    \n ------- \n
-
+    And here are the entities, relationships and terms extracted from the context:
+    \n ------- \n
+    {entity_term_extraction_str}
+    \n ------- \n
+
+    Please generate the list of good, fully contextualized sub-questions that would help to address the
+    main question. Again, please find questions that are NOT overlapping too much with the already answered
+    sub-questions or those that already were suggested and failed.
+    In other words - what can we try in addition to what has been tried so far?
+
+    Please think through it step by step and then generate the list of json dictionaries with the following
+    format:
+
+    {{"sub_questions": [{{"sub_question": <sub-question>,
+        "explanation": <explanation>,
+        "search_term": <search term>}},
+        ...]}} """
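Reviewer's note, not part of the diff itself: the doubled braces in these templates are str.format() escapes, so the JSON examples survive formatting as literal braces while the single-brace fields are substituted. A two-line demonstration with a hypothetical template:

    template = 'Question: {question}\nReply as: [{{"sub_question": <text>, "search_term": <text>}}]'
    print(template.format(question="What is the capital of France?"))
    # -> {question} is filled in; {{...}} comes out as literal {...}.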
+
+DEEP_DECOMPOSE_PROMPT = """ \n
+    An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
+    good enough. Also, some sub-questions had been answered and this information has been used to provide
+    the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they
+    were not directly answerable. Also, some entities, relationships and terms are given to you so that
+    you have an idea of what the available data looks like.
+
+    Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
+    considering:
+
+    1) The initial question
+    2) The initial answer that was found to be unsatisfactory
+    3) The sub-questions that were answered
+    4) The sub-questions that were suggested but not answered
+    5) The entities, relationships and terms that were extracted from the context
+
+    The individual questions should be answerable by a good RAG system.
+    So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
+    question for different entities that may be involved in the original question, but in a way that does
+    not duplicate questions that were already tried.
+
+    Additional Guidelines:
+    - The sub-questions should be specific to the question and provide richer context for the question,
+    resolve ambiguities, or address shortcomings of the initial answer
+    - Each sub-question - when answered - should be relevant for the answer to the original question
+    - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
+    other complications that may require extra context.
+    - The sub-questions MUST have the full context of the original question so that it can be executed by
+    a RAG system independently without the original question available
+    (Example:
+    - initial question: "What is the capital of France?"
+    - bad sub-question: "What is the name of the river there?"
+    - good sub-question: "What is the name of the river that flows through Paris?")
+    - For each sub-question, please also provide a search term that can be used to retrieve relevant
+    documents from a document store.
+    \n\n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Here is the initial sub-optimal answer:
+    \n ------- \n
+    {base_answer}
+    \n ------- \n
+
+    Here are the sub-questions that were answered:
+    \n ------- \n
+    {answered_sub_questions}
+    \n ------- \n
+
+    Here are the sub-questions that were suggested but not answered:
+    \n ------- \n
+    {failed_sub_questions}
+    \n ------- \n
+
+    And here are the entities, relationships and terms extracted from the context:
+    \n ------- \n
+    {entity_term_extraction_str}
+    \n ------- \n
+
+    Please generate the list of good, fully contextualized sub-questions that would help to address the
+    main question. Again, please find questions that are NOT overlapping too much with the already answered
+    sub-questions or those that already were suggested and failed.
+    In other words - what can we try in addition to what has been tried so far?
+
+    Generate the list of json dictionaries with the following format:
+
+    {{"sub_questions": [{{"sub_question": <sub-question>,
+        "search_term": <search term>}},
+        ...]}} """
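Reviewer's note, not part of the diff itself: a reply to DEEP_DECOMPOSE_PROMPT would be parsed with clean_and_parse_json_string from shared_graph_utils/utils.py (added later in this patch). A minimal sketch; the reply text is fabricated:

    import json
    import re

    def clean_and_parse_json_string(json_string: str) -> dict:
        # Same cleanup as the helper added to shared_graph_utils/utils.py.
        cleaned = re.sub(r"```json\n|\n```", "", json_string)
        cleaned = cleaned.replace("\\n", " ").replace("\n", " ")
        cleaned = " ".join(cleaned.split())
        return json.loads(cleaned)

    reply = (
        '{"sub_questions": [{"sub_question": "What are sales for company A?",'
        ' "search_term": "company A sales"}]}'
    )
    for sub in clean_and_parse_json_string(reply)["sub_questions"]:
        print(sub["sub_question"], "->", sub["search_term"])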
+
+DECOMPOSE_PROMPT = """ \n
+    For an initial user question, please generate 5-10 individual sub-questions whose answers would help
+    \n to answer the initial question. The individual questions should be answerable by a good RAG system.
+    So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
     question for different entities that may be involved in the original question.
+    In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
+    document store, expressed as entities, relationships and terms. You can also think about the types
+    mentioned in brackets.
+
     Guidelines:
-    - The sub-questions should be specific to the question and provide
-    richer context for the question, and or resolve ambiguities
-    - Each sub-question - when answered - should be relevant for the
-    answer to the original question
-    - The sub-questions MUST have the full context of the original
-    question so that it can be executed by a RAG system independently
-    without the original question available
+    - The sub-questions should be specific to the question and provide richer context for the question,
+    and/or resolve ambiguities
+    - Each sub-question - when answered - should be relevant for the answer to the original question
+    - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
+    other complications that may require extra context.
+    - The sub-questions MUST have the full context of the original question so that it can be executed by
+    a RAG system independently without the original question available
     (Example:
     - initial question: "What is the capital of France?"
     - bad sub-question: "What is the name of the river there?"
-    - good sub-question: "What is the name of the river that flows
-    through Paris?"
+    - good sub-question: "What is the name of the river that flows through Paris?")
+    - For each sub-question, please provide a short explanation for why it is a good sub-question. So
+    generate a list of dictionaries with the following format:
+    [{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <search term>}}, ...]
 
     \n\n
     Here is the initial question:
     \n ------- \n
     {question}
     \n ------- \n
-    Please generate the list of good, fully contextualized sub-questions.
-    Think through it step by step and then generate the list.
-"""
+
+    And here are the entities, relationships and terms extracted from the context:
+    \n ------- \n
+    {entity_term_extraction_str}
+    \n ------- \n
+
+    Please generate the list of good, fully contextualized sub-questions that would help to address the
+    main question. Don't be too specific unless the original question is specific.
+    Please think through it step by step and then generate the list of json dictionaries with the following
+    format:
+    {{"sub_questions": [{{"sub_question": <sub-question>,
+        "explanation": <explanation>,
+        "search_term": <search term>}},
+        ...]}} """
 
 #### Consolidations
-COMBINED_CONTEXT = """
-    -------
-    Below you will find useful information to answer the original question.
-    First, you see a number of sub-questions with their answers.
-    This information should be considered to be more focussed and
-    somewhat more specific to the original question as it tries to
-    contextualized facts.
-    After that will see the documents that were considered to be
-    relevant to answer the original question.
+COMBINED_CONTEXT = """-------
+    Below you will find useful information to answer the original question. First, you see a number of
+    sub-questions with their answers. This information should be considered to be more focused and
+    somewhat more specific to the original question, as it tries to contextualize facts.
+    After that you will see the documents that were considered to be relevant to answer the original question.
 
     Here are the sub-questions and their answers:
     \n\n {deep_answer_context} \n\n
-    \n\n Here are the documents that were considered to be relevant to
-    answer the original question:
+    \n\n Here are the documents that were considered to be relevant to answer the original question:
     \n\n {formated_docs} \n\n
     ----------------
-"""
+    """
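Reviewer's note, not part of the diff itself: COMBINED_CONTEXT is filled with the consolidated sub-Q/A text and the formatted documents, as the deep-answer path does elsewhere in this patch (note the template's keyword is spelled formated_docs). A sketch with fabricated values and an abbreviated stand-in:

    # Abbreviated stand-in for COMBINED_CONTEXT above (same placeholders).
    COMBINED_STUB = (
        "Here are the sub-questions and their answers:\n{deep_answer_context}\n"
        "Here are the relevant documents:\n{formated_docs}"
    )

    print(COMBINED_STUB.format(
        deep_answer_context="Q: What are sales for company A?\nA: About $10M.",
        formated_docs="Doc 1: FY24 sales for company A were $10M.",
    ))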
+
+SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
+    Below you will find a question that we ultimately want to answer (the original question) and a list of
+    motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
+    original question. The motivations are formatted as <motivation number>: <motivation>.
+    (Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
+    motivation and 2 is less relevant.)
+
+    Please rank the motivations in order of relevance for answering the original question. Also, try to
+    ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
+    Ultimately, create a list with the motivation numbers where the number of the most relevant
+    motivations comes first.
+
+    Here is the original question:
+    \n\n {original_question} \n\n
+    \n\n Here is the list of sub-question motivations:
+    \n\n {sub_question_explanations} \n\n
+    ----------------
+
+    Please think step by step and then generate the ranked list of motivations.
+
+    Please format your answer as a json object in the following format:
+    {{"reasonning": <your reasoning for the ranking>,
+    "ranked_motivations": <ranked list of motivation numbers>}}
+    """
diff --git a/backend/danswer/agent_search/shared_graph_utils/utils.py b/backend/danswer/agent_search/shared_graph_utils/utils.py
index 1a3383ccd80..4bfd9fa87a8 100644
--- a/backend/danswer/agent_search/shared_graph_utils/utils.py
+++ b/backend/danswer/agent_search/shared_graph_utils/utils.py
@@ -1,4 +1,10 @@
+import ast
+import json
+import re
 from collections.abc import Sequence
+from datetime import datetime
+from datetime import timedelta
+from typing import Any
 
 from danswer.chat.models import DanswerContext
 
@@ -13,3 +19,73 @@ def normalize_whitespace(text: str) -> str:
 # Post-processing
 def format_docs(docs: Sequence[DanswerContext]) -> str:
     return "\n\n".join(doc.content for doc in docs)
+
+
+def clean_and_parse_list_string(json_string: str) -> list[dict]:
+    # Remove markdown code block markers and any newline prefixes
+    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
+    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
+    cleaned_string = " ".join(cleaned_string.split())
+    # Parse the cleaned string into a Python list
+    return ast.literal_eval(cleaned_string)
+
+
+def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
+    # Remove markdown code block markers and any newline prefixes
+    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
+    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
+    cleaned_string = " ".join(cleaned_string.split())
+    # Parse the cleaned string into a Python dictionary
+    return json.loads(cleaned_string)
+
+
+def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
+    entities = entity_term_extraction_dict["entities"]
+    terms = entity_term_extraction_dict["terms"]
+    relationships = entity_term_extraction_dict["relationships"]
+
+    entity_strs = ["\nEntities:\n"]
+    for entity in entities:
+        entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
+        entity_strs.append(entity_str)
+
+    relationship_strs = ["\n\nRelationships:\n"]
+    for relationship in relationships:
+        relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
+        relationship_strs.append(relationship_str)
+
+    term_strs = ["\n\nTerms:\n"]
+    for term in terms:
+        term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
+        term_strs.append(term_str)
+
+    return "\n".join(entity_strs + relationship_strs + term_strs)
+
+
+def _format_time_delta(time: timedelta) -> str:
+    seconds_from_start = f"{time.seconds:03d}"
+    microseconds_from_start = f"{time.microseconds:06d}"
+    return f"{seconds_from_start}.{microseconds_from_start}"
+
+
+def generate_log_message(
+    message: str,
+    node_start_time: datetime,
+    graph_start_time: datetime | None = None,
+) -> str:
+    current_time = datetime.now()
+
+    if graph_start_time is not None:
+        graph_time_str = _format_time_delta(current_time - graph_start_time)
+    else:
+        graph_time_str = "N/A"
+
+    node_time_str = _format_time_delta(current_time - node_start_time)
+
+    return f"{graph_time_str} ({node_time_str} s): {message}"

From 617726207b236064301ed1d0eb754edb60d2c906 Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Sat, 7 Dec 2024 06:06:22 -0800
Subject: [PATCH 03/10] all 3 graphs r done

---
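Reviewer's note, not part of the diff itself: a quick illustration of what the timing helpers in the utils.py hunk above produce; the timestamps are fabricated.

    from datetime import datetime, timedelta

    def _format_time_delta(delta: timedelta) -> str:
        # Mirrors the helper above: zero-padded seconds.microseconds.
        return f"{delta.seconds:03d}.{delta.microseconds:06d}"

    graph_start = datetime(2024, 12, 7, 6, 0, 0)
    node_start = graph_start + timedelta(seconds=2)
    now = node_start + timedelta(seconds=1, microseconds=250000)

    # Same shape as generate_log_message's output:
    # "<since graph start> (<since node start> s): <message>"
    print(
        f"{_format_time_delta(now - graph_start)} "
        f"({_format_time_delta(now - node_start)} s): core - rewrite"
    )
    # -> 003.250000 (001.250000 s): core - rewrite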
.../agent_search/base_qa_sub_graph/edges.py | 44 +++ .../base_qa_sub_graph/graph_builder.py | 104 ++++++ .../agent_search/base_qa_sub_graph/nodes.py | 306 +++++++++++++++++ .../agent_search/base_qa_sub_graph/prompts.py | 13 + .../{subgraph => base_qa_sub_graph}/states.py | 42 ++- .../agent_search/primary_graph/edges.py | 38 ++- .../primary_graph/graph_builder.py | 2 +- .../agent_search/primary_graph/nodes.py | 57 ++-- .../agent_search/primary_graph/prompts.py | 86 +++++ .../agent_search/primary_graph/states.py | 5 +- .../research_qa_sub_graph/edges.py | 46 +++ .../research_qa_sub_graph/graph_builder.py | 87 +++++ .../research_qa_sub_graph/nodes.py | 308 +++++++++++++++++ .../research_qa_sub_graph/states.py | 60 ++++ .../agent_search/shared_graph_utils/models.py | 16 + .../shared_graph_utils/prompts.py | 177 +++------- .../danswer/agent_search/subgraph/edges.py | 25 -- .../agent_search/subgraph/graph_builder.py | 59 ---- .../danswer/agent_search/subgraph/nodes.py | 316 ------------------ 19 files changed, 1179 insertions(+), 612 deletions(-) create mode 100644 backend/danswer/agent_search/base_qa_sub_graph/edges.py create mode 100644 backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py create mode 100644 backend/danswer/agent_search/base_qa_sub_graph/nodes.py create mode 100644 backend/danswer/agent_search/base_qa_sub_graph/prompts.py rename backend/danswer/agent_search/{subgraph => base_qa_sub_graph}/states.py (65%) create mode 100644 backend/danswer/agent_search/primary_graph/prompts.py create mode 100644 backend/danswer/agent_search/research_qa_sub_graph/edges.py create mode 100644 backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py create mode 100644 backend/danswer/agent_search/research_qa_sub_graph/nodes.py create mode 100644 backend/danswer/agent_search/research_qa_sub_graph/states.py create mode 100644 backend/danswer/agent_search/shared_graph_utils/models.py delete mode 100644 backend/danswer/agent_search/subgraph/edges.py delete mode 100644 backend/danswer/agent_search/subgraph/graph_builder.py delete mode 100644 backend/danswer/agent_search/subgraph/nodes.py diff --git a/backend/danswer/agent_search/base_qa_sub_graph/edges.py b/backend/danswer/agent_search/base_qa_sub_graph/edges.py new file mode 100644 index 00000000000..08118446f38 --- /dev/null +++ b/backend/danswer/agent_search/base_qa_sub_graph/edges.py @@ -0,0 +1,44 @@ +from collections.abc import Hashable +from typing import Union + +from langgraph.types import Send + +from danswer.agent_search.base_qa_sub_graph.states import BaseQAState +from danswer.agent_search.primary_graph.states import RetrieverState +from danswer.agent_search.primary_graph.states import VerifierState + + +def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]: + # Routes each de-douped retrieved doc to the verifier step - in parallel + # Notice the 'Send()' API that takes care of the parallelization + + return [ + Send( + "sub_verifier", + VerifierState( + document=doc, + question=state["sub_question_str"], + fast_llm=state["fast_llm"], + primary_llm=state["primary_llm"], + graph_start_time=state["graph_start_time"], + ), + ) + for doc in state["sub_question_base_retrieval_docs"] + ] + + +def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]: + # Routes re-written queries to the (parallel) retrieval steps + # Notice the 'Send()' API that takes care of the parallelization + return [ + Send( + "sub_custom_retrieve", + RetrieverState( + rewritten_query=query, + 
primary_llm=state["primary_llm"], + fast_llm=state["fast_llm"], + graph_start_time=state["graph_start_time"], + ), + ) + for query in state["sub_question_search_queries"] + ] diff --git a/backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py b/backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py new file mode 100644 index 00000000000..62943930c0f --- /dev/null +++ b/backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py @@ -0,0 +1,104 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from danswer.agent_search.base_qa_sub_graph.edges import sub_continue_to_retrieval +from danswer.agent_search.base_qa_sub_graph.edges import sub_continue_to_verifier +from danswer.agent_search.base_qa_sub_graph.nodes import sub_combine_retrieved_docs +from danswer.agent_search.base_qa_sub_graph.nodes import sub_custom_retrieve +from danswer.agent_search.base_qa_sub_graph.nodes import sub_dummy +from danswer.agent_search.base_qa_sub_graph.nodes import sub_final_format +from danswer.agent_search.base_qa_sub_graph.nodes import sub_generate +from danswer.agent_search.base_qa_sub_graph.nodes import sub_qa_check +from danswer.agent_search.base_qa_sub_graph.nodes import sub_verifier +from danswer.agent_search.base_qa_sub_graph.states import BaseQAOutputState +from danswer.agent_search.base_qa_sub_graph.states import BaseQAState + +# from danswer.agent_search.base_qa_sub_graph.nodes import sub_rewrite + + +def build_base_qa_sub_graph() -> StateGraph: + sub_answers_initial = StateGraph( + state_schema=BaseQAState, + output=BaseQAOutputState, + ) + + ### Add nodes ### + # sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite) + sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy) + sub_answers_initial.add_node( + node="sub_custom_retrieve", + action=sub_custom_retrieve, + ) + sub_answers_initial.add_node( + node="sub_combine_retrieved_docs", + action=sub_combine_retrieved_docs, + ) + sub_answers_initial.add_node( + node="sub_verifier", + action=sub_verifier, + ) + sub_answers_initial.add_node( + node="sub_generate", + action=sub_generate, + ) + sub_answers_initial.add_node( + node="sub_qa_check", + action=sub_qa_check, + ) + sub_answers_initial.add_node( + node="sub_final_format", + action=sub_final_format, + ) + + ### Add edges ### + sub_answers_initial.add_edge(START, "sub_rewrite") + + sub_answers_initial.add_conditional_edges( + source="sub_rewrite", + path=sub_continue_to_retrieval, + path_map=["sub_custom_retrieve"], + ) + + sub_answers_initial.add_edge( + start_key="sub_custom_retrieve", + end_key="sub_combine_retrieved_docs", + ) + + sub_answers_initial.add_conditional_edges( + source="sub_combine_retrieved_docs", + path=sub_continue_to_verifier, + path_map=["sub_verifier"], + ) + + sub_answers_initial.add_edge( + start_key="sub_verifier", + end_key="sub_generate", + ) + + sub_answers_initial.add_edge( + start_key="sub_generate", + end_key="sub_qa_check", + ) + + sub_answers_initial.add_edge( + start_key="sub_qa_check", + end_key="sub_final_format", + ) + + sub_answers_initial.add_edge( + start_key="sub_final_format", + end_key=END, + ) + # sub_answers_graph = sub_answers_initial.compile() + return sub_answers_initial + + +if __name__ == "__main__": + # TODO: add the actual question + inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"} + sub_answers_graph = build_base_qa_sub_graph() + compiled_sub_answers = sub_answers_graph.compile() + output = 
compiled_sub_answers.invoke(inputs) + print("\nOUTPUT:") + print(output) diff --git a/backend/danswer/agent_search/base_qa_sub_graph/nodes.py b/backend/danswer/agent_search/base_qa_sub_graph/nodes.py new file mode 100644 index 00000000000..e616ef906eb --- /dev/null +++ b/backend/danswer/agent_search/base_qa_sub_graph/nodes.py @@ -0,0 +1,306 @@ +import datetime +import json +from typing import Any + +from langchain_core.messages import HumanMessage + +from danswer.agent_search.base_qa_sub_graph.states import BaseQAState +from danswer.agent_search.primary_graph.states import RetrieverState +from danswer.agent_search.primary_graph.states import VerifierState +from danswer.agent_search.shared_graph_utils.models import BinaryDecision +from danswer.agent_search.shared_graph_utils.models import RewrittenQueries +from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_docs +from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.chat.models import DanswerContext +from danswer.llm.interfaces import LLM + + +# unused at this point. Kept from tutorial. But here the agent makes a routing decision +# not that much of an agent if the routing is static... +def sub_rewrite(state: BaseQAState) -> dict[str, Any]: + """ + Transform the initial question into more suitable search queries. + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + + print("---SUB TRANSFORM QUERY---") + + node_start_time = datetime.datetime.now() + + # messages = state["base_answer_messages"] + question = state["sub_question_str"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI.format(question=question), + ) + ] + fast_llm: LLM = state["fast_llm"] + llm_response = list( + fast_llm.stream( + prompt=msg, + structured_response_format=RewrittenQueries.model_json_schema(), + ) + ) + + # Get the rewritten queries in a defined format + rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr()) + + print(f"rewritten_queries: {rewritten_queries}") + + rewritten_queries = RewrittenQueries( + rewritten_queries=[ + "music hard to listen to", + "Music that is not fun or pleasant", + ] + ) + + print(f"hardcoded rewritten_queries: {rewritten_queries}") + + return { + "sub_question_rewritten_queries": rewritten_queries, + "log_messages": generate_log_message( + message="sub - rewrite", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } + + +# dummy node to report on state if needed +def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: + """ + Retrieve documents + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print("---RETRIEVE SUB---") + + node_start_time = datetime.datetime.now() + + # rewritten_query = state["rewritten_query"] + + # Retrieval + # TODO: add the actual retrieval, probably from search_tool.run() + documents: list[DanswerContext] = [] + + return { + "sub_question_base_retrieval_docs": documents, + "log_messages": generate_log_message( + message="sub - custom_retrieve", + node_start_time=node_start_time, + 
graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
+    """
+    Dedupe the retrieved docs.
+    """
+    node_start_time = datetime.datetime.now()
+
+    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
+
+    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
+    dedupe_docs = []
+    for base_retrieval_doc in sub_question_base_retrieval_docs:
+        if base_retrieval_doc not in dedupe_docs:
+            dedupe_docs.append(base_retrieval_doc)
+
+    print(f"Number of deduped docs: {len(dedupe_docs)}")
+
+    return {
+        "sub_question_deduped_retrieval_docs": dedupe_docs,
+        "log_messages": generate_log_message(
+            message="sub - combine_retrieved_docs (dedupe)",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_verifier(state: VerifierState) -> dict[str, Any]:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        state (VerifierState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---VERIFY OUTPUT---")
+    node_start_time = datetime.datetime.now()
+
+    question = state["question"]
+    document_content = state["document"].content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=question, document_content=document_content
+            )
+        )
+    ]
+
+    # Grader
+    llm: LLM = state["primary_llm"]
+    response = list(
+        llm.stream(
+            prompt=msg,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    raw_response = json.loads(response[0].pretty_repr())
+    formatted_response = BinaryDecision.model_validate(raw_response)
+
+    print(f"Verdict: {formatted_response.decision}")
+
+    return {
+        "sub_question_verified_retrieval_docs": [state["document"]]
+        if formatted_response.decision == "yes"
+        else [],
+        "log_messages": generate_log_message(
+            message=f"sub - verifier: {formatted_response.decision}",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_generate(state: BaseQAState) -> dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the sub-question answer
+    """
+    print("---GENERATE---")
+    node_start_time = datetime.datetime.now()
+
+    question = state["original_question"]
+    docs = state["sub_question_verified_retrieval_docs"]
+
+    print(f"Number of verified retrieval docs: {len(docs)}")
+
+    msg = [
+        HumanMessage(
+            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
+        )
+    ]
+
+    # Grader
+    llm: LLM = state["primary_llm"]
+    response = list(
+        llm.stream(
+            prompt=msg,
+            structured_response_format=None,
+        )
+    )
+
+    answer = response[0].pretty_repr()
+    return {
+        "sub_question_answer": answer,
+        "log_messages": generate_log_message(
+            message="base - generate",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_final_format(state: BaseQAState) -> dict[str, Any]:
+    """
+    Create the final output for the QA subgraph
+    """
+
+    print("---BASE FINAL FORMAT---")
+
+    return {
+        "sub_qas": [
+            {
+                "sub_question": state["sub_question_str"],
+                "sub_answer": state["sub_question_answer"],
+                "sub_answer_check": state["sub_question_answer_check"],
+            }
+        ],
+        "log_messages": state["log_messages"],
+    }
+
+
+def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
+    """
+    Check if the sub-question answer is satisfactory.
+ + Args: + state: The current SubQAState containing the sub-question and its answer + + Returns: + dict containing the check result and log message + """ + + msg = [ + HumanMessage( + content=BASE_CHECK_PROMPT.format( + question=state["sub_question_str"], + base_answer=state["sub_question_answer"], + ) + ) + ] + + model: LLM = state["primary_llm"] + response = list( + model.stream( + prompt=msg, + structured_response_format=None, + ) + ) + + start_time = datetime.datetime.now() + + return { + "sub_question_answer_check": response[0].pretty_repr().lower(), + "base_answer_messages": generate_log_message( + message="sub - qa_check", + node_start_time=start_time, + graph_start_time=state["graph_start_time"], + ), + } + + +def sub_dummy(state: BaseQAState) -> dict[str, Any]: + """ + Dummy step + """ + + print("---Sub Dummy---") + + node_start_time = datetime.datetime.now() + + return { + "log_messages": generate_log_message( + message="sub - dummy", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/base_qa_sub_graph/prompts.py b/backend/danswer/agent_search/base_qa_sub_graph/prompts.py new file mode 100644 index 00000000000..4e983873b70 --- /dev/null +++ b/backend/danswer/agent_search/base_qa_sub_graph/prompts.py @@ -0,0 +1,13 @@ +SUB_CHECK_PROMPT = """ \n + Please check whether the suggested answer seems to address the original question. + + Please only answer with 'yes' or 'no' \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Here is the proposed answer: + \n ------- \n + {base_answer} + \n ------- \n + Please answer with yes or no:""" diff --git a/backend/danswer/agent_search/subgraph/states.py b/backend/danswer/agent_search/base_qa_sub_graph/states.py similarity index 65% rename from backend/danswer/agent_search/subgraph/states.py rename to backend/danswer/agent_search/base_qa_sub_graph/states.py index d0f18fc21ea..f14533035d8 100644 --- a/backend/danswer/agent_search/subgraph/states.py +++ b/backend/danswer/agent_search/base_qa_sub_graph/states.py @@ -1,7 +1,7 @@ import operator from collections.abc import Sequence +from datetime import datetime from typing import Annotated -from typing import List from typing import TypedDict from langchain_core.messages import BaseMessage @@ -22,11 +22,16 @@ class SubQuestionVerifierState(TypedDict): sub_question: str -class SubQAState(TypedDict): +class BaseQAState(TypedDict): # The 'core SubQuestion' state. 
original_question: str - sub_question_rewritten_queries: List[str] - sub_question: str + graph_start_time: datetime + # start time for parallel initial sub-questionn thread + sub_query_start_time: datetime + sub_question_rewritten_queries: list[str] + sub_question_str: str + sub_question_search_queries: list[str] + sub_question_nr: int sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] sub_question_deduped_retrieval_docs: Annotated[ Sequence[DanswerContext], operator.add @@ -37,24 +42,27 @@ class SubQAState(TypedDict): sub_question_reranked_retrieval_docs: Annotated[ Sequence[DanswerContext], operator.add ] - sub_question_top_chunks: Annotated[Sequence[DanswerContext], operator.add] + sub_question_top_chunks: Annotated[Sequence[dict], operator.add] sub_question_answer: str sub_question_answer_check: str log_messages: Annotated[Sequence[BaseMessage], add_messages] - sub_qas: Annotated[ - Sequence[DanswerContext], operator.add - ] # Answers sent back to core - llm: LLM - tools: list[dict] + sub_qas: Annotated[Sequence[dict], operator.add] + # Answers sent back to core + initial_sub_qas: Annotated[Sequence[dict], operator.add] + primary_llm: LLM + fast_llm: LLM -class SubQAOutputState(TypedDict): +class BaseQAOutputState(TypedDict): # The 'SubQuestion' output state. Removes all the intermediate states - sub_question_rewritten_queries: List[str] - sub_question: str - sub_qas: Annotated[ - Sequence[DanswerContext], operator.add - ] # Answers sent back to core + sub_question_rewritten_queries: list[str] + sub_question_str: str + sub_question_search_queries: list[str] + sub_question_nr: int + # Answers sent back to core + sub_qas: Annotated[Sequence[dict], operator.add] + # Answers sent back to core + initial_sub_qas: Annotated[Sequence[dict], operator.add] sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] sub_question_deduped_retrieval_docs: Annotated[ Sequence[DanswerContext], operator.add @@ -65,7 +73,7 @@ class SubQAOutputState(TypedDict): sub_question_reranked_retrieval_docs: Annotated[ Sequence[DanswerContext], operator.add ] - sub_question_top_chunks: Annotated[Sequence[DanswerContext], operator.add] + sub_question_top_chunks: Annotated[Sequence[dict], operator.add] sub_question_answer: str sub_question_answer_check: str log_messages: Annotated[Sequence[BaseMessage], add_messages] diff --git a/backend/danswer/agent_search/primary_graph/edges.py b/backend/danswer/agent_search/primary_graph/edges.py index 238b92b52cf..a21709ec3b3 100644 --- a/backend/danswer/agent_search/primary_graph/edges.py +++ b/backend/danswer/agent_search/primary_graph/edges.py @@ -4,7 +4,9 @@ from langchain_core.messages import HumanMessage from langgraph.types import Send +from danswer.agent_search.base_qa_sub_graph.states import BaseQAState from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT @@ -16,16 +18,16 @@ def continue_to_initial_sub_questions( return [ Send( "sub_answers_graph_initial", - { - "sub_question_str": initial_sub_question["sub_question_str"], - "sub_question_search_queries": initial_sub_question[ + BaseQAState( + sub_question_str=initial_sub_question["sub_question_str"], + sub_question_search_queries=initial_sub_question[ "sub_question_search_queries" ], - "sub_question_nr": initial_sub_question["sub_question_nr"], - "primary_llm": state["primary_llm"], - "fast_llm": 
state["fast_llm"], - "graph_start_time": state["graph_start_time"], - }, + sub_question_nr=initial_sub_question["sub_question_nr"], + primary_llm=state["primary_llm"], + fast_llm=state["fast_llm"], + graph_start_time=state["graph_start_time"], + ), ) for initial_sub_question in state["initial_sub_questions"] ] @@ -37,15 +39,15 @@ def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Has return [ Send( "sub_answers_graph", - { - "sub_question": sub_question, - "sub_question_nr": sub_question_nr, - "primary_llm": state["primary_llm"], - "fast_llm": state["fast_llm"], - "graph_start_time": state["graph_start_time"], - }, + ResearchQAState( + sub_question=sub_question["sub_question_str"], + sub_question_nr=sub_question["sub_question_nr"], + graph_start_time=state["graph_start_time"], + primary_llm=state["primary_llm"], + fast_llm=state["fast_llm"], + ), ) - for sub_question_nr, sub_question in state["sub_questions"].items() + for sub_question in state["sub_questions"] ] @@ -65,9 +67,9 @@ def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]: model = state["fast_llm"] response = model.invoke(BASE_CHECK_MESSAGE) - print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.content}") + print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}") - if response.content == "no": + if response.pretty_repr() == "no": return "decompose" else: return "end" diff --git a/backend/danswer/agent_search/primary_graph/graph_builder.py b/backend/danswer/agent_search/primary_graph/graph_builder.py index 90a0833d54e..20ce65c5588 100644 --- a/backend/danswer/agent_search/primary_graph/graph_builder.py +++ b/backend/danswer/agent_search/primary_graph/graph_builder.py @@ -19,7 +19,7 @@ def build_core_graph() -> StateGraph: # Define the nodes we will cycle between - core_answer_graph = StateGraph(QAState) + core_answer_graph = StateGraph(state_schema=QAState) ### Add Nodes ### diff --git a/backend/danswer/agent_search/primary_graph/nodes.py b/backend/danswer/agent_search/primary_graph/nodes.py index bb5621deabe..45fc14358b0 100644 --- a/backend/danswer/agent_search/primary_graph/nodes.py +++ b/backend/danswer/agent_search/primary_graph/nodes.py @@ -3,15 +3,14 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from typing import Dict -from typing import Literal from langchain_core.messages import HumanMessage -from pydantic import BaseModel from danswer.agent_search.primary_graph.states import QAState from danswer.agent_search.primary_graph.states import RetrieverState from danswer.agent_search.primary_graph.states import VerifierState +from danswer.agent_search.shared_graph_utils.models import BinaryDecision +from danswer.agent_search.shared_graph_utils.models import RewrittenQueries from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from danswer.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT from danswer.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT @@ -28,21 +27,8 @@ # from typing import Partial -# Pydantic models for structured outputs -class RewrittenQueries(BaseModel): - rewritten_queries: list[str] - - -class BinaryDecision(BaseModel): - decision: Literal["yes", "no"] - - -class SubQuestions(BaseModel): - sub_questions: list[str] - - # Transform the initial question into more suitable search queries. 
-def rewrite(state: QAState) -> Dict[str, Any]: +def rewrite(state: QAState) -> dict[str, Any]: """ Transform the initial question into more suitable search queries. @@ -75,7 +61,7 @@ def rewrite(state: QAState) -> Dict[str, Any]: ) ) - formatted_response: RewrittenQueries = json.loads(llm_response[0].content) + formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr()) return { "rewritten_queries": formatted_response.rewritten_queries, @@ -87,7 +73,7 @@ def rewrite(state: QAState) -> Dict[str, Any]: } -def custom_retrieve(state: RetrieverState) -> Dict[str, Any]: +def custom_retrieve(state: RetrieverState) -> dict[str, Any]: """ Retrieve documents @@ -117,7 +103,7 @@ def custom_retrieve(state: RetrieverState) -> Dict[str, Any]: } -def combine_retrieved_docs(state: QAState) -> Dict[str, Any]: +def combine_retrieved_docs(state: QAState) -> dict[str, Any]: """ Dedupe the retrieved docs. """ @@ -145,7 +131,7 @@ def combine_retrieved_docs(state: QAState) -> Dict[str, Any]: } -def verifier(state: VerifierState) -> Dict[str, Any]: +def verifier(state: VerifierState) -> dict[str, Any]: """ Check whether the document is relevant for the original user question @@ -159,7 +145,7 @@ def verifier(state: VerifierState) -> Dict[str, Any]: print("---VERIFY QUTPUT---") node_start_time = datetime.now() - question = state["original_question"] + question = state["question"] document_content = state["document"].content msg = [ @@ -179,7 +165,8 @@ def verifier(state: VerifierState) -> Dict[str, Any]: ) ) - formatted_response: BinaryDecision = response[0].content + raw_response = json.loads(response[0].pretty_repr()) + formatted_response = BinaryDecision.model_validate(raw_response) return { "deduped_retrieval_docs": [state["document"]] @@ -193,7 +180,7 @@ def verifier(state: VerifierState) -> Dict[str, Any]: } -def generate(state: QAState) -> Dict[str, Any]: +def generate(state: QAState) -> dict[str, Any]: """ Generate answer @@ -231,7 +218,7 @@ def generate(state: QAState) -> Dict[str, Any]: # "question": question}) return { - "base_answer": response[0].content, + "base_answer": response[0].pretty_repr(), "log_messages": generate_log_message( message="core - generate", node_start_time=node_start_time, @@ -240,7 +227,7 @@ def generate(state: QAState) -> Dict[str, Any]: } -def final_stuff(state: QAState) -> Dict[str, Any]: +def final_stuff(state: QAState) -> dict[str, Any]: """ Invokes the agent model to generate a response based on the current state. Given the question, it will decide to retrieve using the retriever tool, or simply end. 
@@ -255,7 +242,7 @@ def final_stuff(state: QAState) -> Dict[str, Any]:
     node_start_time = datetime.now()
 
     messages = state["log_messages"]
-    time_ordered_messages = [x.content for x in messages]
+    time_ordered_messages = [x.pretty_repr() for x in messages]
     time_ordered_messages.sort()
 
     print("Message Log:")
@@ -320,7 +307,7 @@ def final_stuff(state: QAState) -> Dict[str, Any]:
     }
 
 
-def base_wait(state: QAState) -> Dict[str, Any]:
+def base_wait(state: QAState) -> dict[str, Any]:
     """
     Ensures that all required steps are completed before proceeding to the next step
 
@@ -342,7 +329,7 @@ def base_wait(state: QAState) -> Dict[str, Any]:
     }
 
 
-def entity_term_extraction(state: QAState) -> Dict[str, Any]:
+def entity_term_extraction(state: QAState) -> dict[str, Any]:
     """ """
     node_start_time = datetime.now()
 
@@ -362,7 +349,7 @@ def entity_term_extraction(state: QAState) -> Dict[str, Any]:
     model = state["fast_llm"]
     response = model.invoke(msg)
 
-    cleaned_response = re.sub(r"```json\n|\n```", "", response.content)
+    cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
     parsed_response = json.loads(cleaned_response)
 
     return {
@@ -375,7 +362,7 @@ def entity_term_extraction(state: QAState) -> Dict[str, Any]:
     }
 
 
-def generate_initial(state: QAState) -> Dict[str, Any]:
+def generate_initial(state: QAState) -> dict[str, Any]:
     """
     Generate answer
 
@@ -433,7 +420,7 @@ def generate_initial(state: QAState) -> Dict[str, Any]:
     #             "question": question})
 
     return {
-        "base_answer": response.content,
+        "base_answer": response.pretty_repr(),
         "log_messages": generate_log_message(
             message="core - generate initial",
             node_start_time=node_start_time,
@@ -442,7 +429,7 @@ def generate_initial(state: QAState) -> Dict[str, Any]:
     }
 
 
-def main_decomp_base(state: QAState) -> Dict[str, Any]:
+def main_decomp_base(state: QAState) -> dict[str, Any]:
     """
     Perform an initial question decomposition, incl. one search term
 
@@ -475,7 +462,7 @@ def main_decomp_base(state: QAState) -> Dict[str, Any]:
     model = state["fast_llm"]
     response = model.invoke(msg)
 
-    content = response.content
+    content = response.pretty_repr()
     list_of_subquestions = clean_and_parse_list_string(content)
 
     # response = model.invoke(msg)
@@ -496,7 +483,7 @@ def main_decomp_base(state: QAState) -> Dict[str, Any]:
 
     return {
         "initial_sub_questions": decomp_list,
-        "start_time_temp": node_start_time,
+        "sub_query_start_time": node_start_time,
         "log_messages": generate_log_message(
             message="core - initial decomp",
             node_start_time=node_start_time,
diff --git a/backend/danswer/agent_search/primary_graph/prompts.py b/backend/danswer/agent_search/primary_graph/prompts.py
new file mode 100644
index 00000000000..0eafd70d275
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/prompts.py
@@ -0,0 +1,86 @@
+INITIAL_DECOMPOSITION_PROMPT = """ \n
+    Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
+    answer the original question. The purpose for this decomposition is to isolate individual entities
+    (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+    for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+    sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+    for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
+
+    For each sub-question, please also create one search term that can be used to retrieve relevant
+    documents from a document store.
+
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Please formulate your answer as a list of json objects with the following format:
+
+    [{{"sub_question": <sub_question>, "search_term": <search_term>}}, ...]
+
+    Answer:
+    """
+
+INITIAL_RAG_PROMPT = """ \n
+    You are an assistant for question-answering tasks. Use the information provided below - and only the
+    provided information - to answer the provided question.
+
+    The information provided below consists of:
+    1) a number of answered sub-questions - these are very important(!) and definitely should be
+    considered to answer the question.
+    2) a number of documents that were also deemed relevant for the question.
+
+    If you don't know the answer or if the provided information is empty or insufficient, just say
+    "I don't know". Do not use your internal knowledge!
+
+    Again, only use the provided information and do not use your internal knowledge! It is a matter of life
+    and death that you do NOT use your internal knowledge, just the provided information!
+
+    Try to keep your answer concise.
+
+    And here is the question and the provided information:
+    \n
+    \nQuestion:\n {question}
+
+    \nAnswered Sub-questions:\n {answered_sub_questions}
+
+    \nContext:\n {context} \n\n
+    \n\n
+
+    Answer:"""
+
+ENTITY_TERM_PROMPT = """ \n
+    Based on the original question and the context retrieved from a dataset, please generate a list of
+    entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
+    (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
+
+    \n\n
+    Here is the original question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    And here is the context retrieved:
+    \n ------- \n
+    {context}
+    \n ------- \n
+
+    Please format your answer as a json object in the following format:
+
+    {{"retrieved_entities_relationships": {{
+        "entities": [{{
+            "entity_name": <entity name>,
+            "entity_type": <entity type>
+        }}],
+        "relationships": [{{
+            "name": <relationship name>,
+            "type": <relationship type>,
+            "entities": [<entity name 1>, <entity name 2>]
+        }}],
+        "terms": [{{
+            "term_name": <term name>,
+            "term_type": <term type>,
+            "similar_to": <list of similar terms>
+        }}]
+    }}
+    }}
+    """
diff --git a/backend/danswer/agent_search/primary_graph/states.py b/backend/danswer/agent_search/primary_graph/states.py
index 90e44d636b7..b692ca4cf1c 100644
--- a/backend/danswer/agent_search/primary_graph/states.py
+++ b/backend/danswer/agent_search/primary_graph/states.py
@@ -15,7 +15,8 @@ class QAState(TypedDict):
     # The 'main' state of the answer graph
     original_question: str
     graph_start_time: datetime
-    sub_query_start_time: datetime  # start time for parallel initial sub-questionn thread
+    # start time for parallel initial sub-question thread
+    sub_query_start_time: datetime
     log_messages: Annotated[Sequence[BaseMessage], add_messages]
     rewritten_queries: list[str]
     sub_questions: list[dict]
@@ -72,7 +73,7 @@ class RetrieverState(TypedDict):
 
 
 class VerifierState(TypedDict):
     # The state for the parallel verification step.
Each node execution needs to see only one question/doc pair
     document: DanswerContext
-    original_question: str
+    question: str
     primary_llm: LLM
     fast_llm: LLM
     graph_start_time: datetime
diff --git a/backend/danswer/agent_search/research_qa_sub_graph/edges.py b/backend/danswer/agent_search/research_qa_sub_graph/edges.py
new file mode 100644
index 00000000000..48afc58206e
--- /dev/null
+++ b/backend/danswer/agent_search/research_qa_sub_graph/edges.py
@@ -0,0 +1,46 @@
+from collections.abc import Hashable
+from typing import Union
+
+from langgraph.types import Send
+
+from danswer.agent_search.primary_graph.states import RetrieverState
+from danswer.agent_search.primary_graph.states import VerifierState
+from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState
+
+
+def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
+    # Routes each de-duped retrieved doc to the verifier step - in parallel
+    # Notice the 'Send()' API that takes care of the parallelization
+
+    return [
+        Send(
+            "sub_verifier",
+            VerifierState(
+                document=doc,
+                question=state["sub_question"],
+                primary_llm=state["primary_llm"],
+                fast_llm=state["fast_llm"],
+                graph_start_time=state["graph_start_time"],
+            ),
+        )
+        for doc in state["sub_question_base_retrieval_docs"]
+    ]
+
+
+def sub_continue_to_retrieval(
+    state: ResearchQAState,
+) -> Union[Hashable, list[Hashable]]:
+    # Routes re-written queries to the (parallel) retrieval steps
+    # Notice the 'Send()' API that takes care of the parallelization
+    return [
+        Send(
+            "sub_custom_retrieve",
+            RetrieverState(
+                rewritten_query=query,
+                primary_llm=state["primary_llm"],
+                fast_llm=state["fast_llm"],
+                graph_start_time=state["graph_start_time"],
+            ),
+        )
+        for query in state["sub_question_rewritten_queries"]
+    ]
diff --git a/backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py b/backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py
new file mode 100644
index 00000000000..40ea5580ad8
--- /dev/null
+++ b/backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py
@@ -0,0 +1,87 @@
+from langgraph.graph import END
+from langgraph.graph import START
+from langgraph.graph import StateGraph
+
+from danswer.agent_search.research_qa_sub_graph.edges import sub_continue_to_retrieval
+from danswer.agent_search.research_qa_sub_graph.edges import sub_continue_to_verifier
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_combine_retrieved_docs
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_custom_retrieve
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_dummy
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_final_format
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_generate
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_qa_check
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_rewrite
+from danswer.agent_search.research_qa_sub_graph.nodes import sub_verifier
+from danswer.agent_search.research_qa_sub_graph.states import ResearchQAOutputState
+from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState
+
+
+def build_base_qa_sub_graph() -> StateGraph:
+    # Define the nodes we will cycle between
+    sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)
+
+    ### Add Nodes ###
+
+    # Dummy node for initial processing
+    sub_answers.add_node(node="sub_dummy", action=sub_dummy)
+
+    # Re-writing the sub-question into search queries
+    sub_answers.add_node(node="sub_rewrite", action=sub_rewrite)
+
+    # The retrieval step
+    sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)
+
+    # The dedupe step
+    sub_answers.add_node(
+        node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
+    )
+
+    # Verifying retrieved information
+    sub_answers.add_node(node="sub_verifier", action=sub_verifier)
+
+    # Generating the response
+    sub_answers.add_node(node="sub_generate", action=sub_generate)
+
+    # Checking the quality of the answer
+    sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)
+
+    # Final formatting of the response
+    sub_answers.add_node(node="sub_final_format", action=sub_final_format)
+
+    ### Add Edges ###
+
+    # Run the dummy pre-processing step first, then rewrite the sub-question
+    sub_answers.add_edge(start_key=START, end_key="sub_dummy")
+
+    sub_answers.add_edge(start_key="sub_dummy", end_key="sub_rewrite")
+
+    sub_answers.add_conditional_edges(
+        source="sub_rewrite",
+        path=sub_continue_to_retrieval,
+        path_map=["sub_custom_retrieve"],
+    )
+
+    sub_answers.add_edge(
+        start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
+    )
+
+    sub_answers.add_conditional_edges(
+        source="sub_combine_retrieved_docs",
+        path=sub_continue_to_verifier,
+        path_map=["sub_verifier"],
+    )
+
+    sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")
+
+    sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")
+
+    sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")
+
+    sub_answers.add_edge(start_key="sub_final_format", end_key=END)
+
+    return sub_answers
+
+
+if __name__ == "__main__":
+    # TODO: add the actual question
+    inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
+    sub_answers_graph = build_base_qa_sub_graph()
+    compiled_sub_answers = sub_answers_graph.compile()
+    output = compiled_sub_answers.invoke(inputs)
+    print("\nOUTPUT:")
+    print(output)
diff --git a/backend/danswer/agent_search/research_qa_sub_graph/nodes.py b/backend/danswer/agent_search/research_qa_sub_graph/nodes.py
new file mode 100644
index 00000000000..85401a5c8fa
--- /dev/null
+++ b/backend/danswer/agent_search/research_qa_sub_graph/nodes.py
@@ -0,0 +1,308 @@
+import json
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.base_qa_sub_graph.states import BaseQAState
+from danswer.agent_search.primary_graph.states import RetrieverState
+from danswer.agent_search.primary_graph.states import VerifierState
+from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState
+from danswer.agent_search.shared_graph_utils.models import BinaryDecision
+from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
+from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
+from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
+from danswer.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
+from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import format_docs
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.chat.models import DanswerContext
+from danswer.llm.interfaces import LLM
+
+
+def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Transform the initial question into more suitable search queries.
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with re-phrased question
+    """
+
+    print("---SUB TRANSFORM QUERY---")
+    node_start_time = datetime.now()
+
+    question = state["sub_question"]
+
+    msg = [
+        HumanMessage(
+            content=REWRITE_PROMPT_MULTI.format(question=question),
+        )
+    ]
+    fast_llm: LLM = state["fast_llm"]
+    llm_response = list(
+        fast_llm.stream(
+            prompt=msg,
+            structured_response_format=RewrittenQueries.model_json_schema(),
+        )
+    )
+
+    # Get the rewritten queries in a defined format
+    raw_rewritten_queries = json.loads(llm_response[0].pretty_repr())
+    rewritten_queries = RewrittenQueries.model_validate(raw_rewritten_queries)
+
+    print(f"rewritten_queries: {rewritten_queries}")
+
+    # NOTE: the queries are temporarily hardcoded for testing
+    rewritten_queries = RewrittenQueries(
+        rewritten_queries=[
+            "music hard to listen to",
+            "Music that is not fun or pleasant",
+        ]
+    )
+
+    print(f"hardcoded rewritten_queries: {rewritten_queries}")
+
+    return {
+        "sub_question_rewritten_queries": rewritten_queries.rewritten_queries,
+        "log_messages": generate_log_message(
+            message="sub - rewrite",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
+    """
+    Retrieve documents
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): New key added to state, documents, that contains retrieved documents
+    """
+    print("---RETRIEVE SUB---")
+    node_start_time = datetime.now()
+
+    # Retrieval
+    # TODO: add the actual retrieval, probably from search_tool.run()
+    documents: list[DanswerContext] = []
+
+    return {
+        "sub_question_base_retrieval_docs": documents,
+        "log_messages": generate_log_message(
+            message="sub - custom_retrieve",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Dedupe the retrieved docs.
+    """
+    node_start_time = datetime.now()
+
+    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
+
+    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
+    dedupe_docs = []
+    for base_retrieval_doc in sub_question_base_retrieval_docs:
+        if base_retrieval_doc not in dedupe_docs:
+            dedupe_docs.append(base_retrieval_doc)
+
+    print(f"Number of deduped docs: {len(dedupe_docs)}")
+
+    return {
+        "sub_question_deduped_retrieval_docs": dedupe_docs,
+        "log_messages": generate_log_message(
+            message="sub - combine_retrieved_docs (dedupe)",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_verifier(state: VerifierState) -> dict[str, Any]:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        state (VerifierState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---SUB VERIFY OUTPUT---")
+    node_start_time = datetime.now()
+
+    question = state["question"]
+    document_content = state["document"].content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=question, document_content=document_content
+            )
+        )
+    ]
+
+    # Grader
+    model = state["fast_llm"]
+    response = model.invoke(msg)
+
+    if response.pretty_repr().lower() == "yes":
+        return {
+            "sub_question_verified_retrieval_docs": [state["document"]],
+            "log_messages": generate_log_message(
+                message="sub - verifier: yes",
+                node_start_time=node_start_time,
+                graph_start_time=state["graph_start_time"],
+            ),
+        }
+    else:
+        return {
+            "sub_question_verified_retrieval_docs": [],
+            "log_messages": generate_log_message(
+                message="sub - verifier: no",
+                node_start_time=node_start_time,
+                graph_start_time=state["graph_start_time"],
+            ),
+        }
+
+
+def sub_generate(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the generated answer
+    """
+    print("---SUB GENERATE---")
+    node_start_time = datetime.now()
+
+    question = state["sub_question"]
+    docs = state["sub_question_verified_retrieval_docs"]
+
+    print(f"Number of verified retrieval docs for sub-question: {len(docs)}")
+
+    msg = [
+        HumanMessage(
+            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
+        )
+    ]
+
+    # Only generate an answer if there are verified docs to ground it on
+    if len(docs) > 0:
+        model = state["fast_llm"]
+        response = model.invoke(msg).pretty_repr()
+    else:
+        response = ""
+
+    return {
+        "sub_question_answer": response,
+        "log_messages": generate_log_message(
+            message="sub - generate",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Create the final output for the QA subgraph
+    """
+
+    print("---SUB FINAL FORMAT---")
+    node_start_time = datetime.now()
+
+    return {
+        # TODO: Type this
+        "sub_qas": [
+            {
+                "sub_question": state["sub_question"],
+                "sub_answer": state["sub_question_answer"],
+                "sub_question_nr": state["sub_question_nr"],
+                "sub_answer_check": state["sub_question_answer_check"],
+            }
+        ],
+        "log_messages": generate_log_message(
+            message="sub - final format",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Check whether the final output satisfies the original user question
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---CHECK SUB OUTPUT---")
+    node_start_time = datetime.now()
+
+    sub_answer = state["sub_question_answer"]
+    sub_question = state["sub_question"]
+
+    msg = [
+        HumanMessage(
+            content=SUB_CHECK_PROMPT.format(
+                question=sub_question, base_answer=sub_answer
+            )
+        )
+    ]
+
+    # Grader
+    model = state["fast_llm"]
+    response = list(
+        model.stream(
+            prompt=msg,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    raw_response = json.loads(response[0].pretty_repr())
+    formatted_response = BinaryDecision.model_validate(raw_response)
+
+    return {
+        "sub_question_answer_check": formatted_response.decision,
+        "log_messages": generate_log_message(
+            message=f"sub - qa check: {formatted_response.decision}",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
+
+
+def sub_dummy(state: BaseQAState) -> dict[str, Any]:
+    """
+    Dummy step
+    """
+
+    print("---Sub Dummy---")
+
+    return {
+        "log_messages": generate_log_message(
+            message="sub - dummy",
+            node_start_time=datetime.now(),
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/research_qa_sub_graph/states.py b/backend/danswer/agent_search/research_qa_sub_graph/states.py
new file mode 100644
index 00000000000..abc0d3b04f1
--- /dev/null
+++ b/backend/danswer/agent_search/research_qa_sub_graph/states.py
@@ -0,0 +1,60 @@
+import operator
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Annotated
+from typing import TypedDict
+
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+from danswer.chat.models import DanswerContext
+from danswer.llm.interfaces import LLM
+
+
+class ResearchQAState(TypedDict):
+    # The 'core SubQuestion' state.
+    original_question: str
+    graph_start_time: datetime
+    sub_question_rewritten_queries: list[str]
+    sub_question: str
+    sub_question_nr: int
+    sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add]
+    sub_question_deduped_retrieval_docs: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]
+    sub_question_verified_retrieval_docs: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]
+    sub_question_reranked_retrieval_docs: Annotated[
+        Sequence[DanswerContext], operator.add
+    ]
+    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
+    sub_question_answer: str
+    sub_question_answer_check: str
+    log_messages: Annotated[Sequence[BaseMessage], add_messages]
+    sub_qas: Annotated[Sequence[dict], operator.add]
+    primary_llm: LLM
+    fast_llm: LLM
+
+
+class ResearchQAOutputState(TypedDict):
+    # The 'SubQuestion' output state.
Removes all the intermediate states + sub_question_rewritten_queries: list[str] + sub_question: str + sub_question_nr: int + # Answers sent back to core + sub_qas: Annotated[Sequence[dict], operator.add] + sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + sub_question_deduped_retrieval_docs: Annotated[ + Sequence[DanswerContext], operator.add + ] + sub_question_verified_retrieval_docs: Annotated[ + Sequence[DanswerContext], operator.add + ] + sub_question_reranked_retrieval_docs: Annotated[ + Sequence[DanswerContext], operator.add + ] + sub_question_top_chunks: Annotated[Sequence[dict], operator.add] + sub_question_answer: str + sub_question_answer_check: str + log_messages: Annotated[Sequence[BaseMessage], add_messages] diff --git a/backend/danswer/agent_search/shared_graph_utils/models.py b/backend/danswer/agent_search/shared_graph_utils/models.py new file mode 100644 index 00000000000..ed731fbc566 --- /dev/null +++ b/backend/danswer/agent_search/shared_graph_utils/models.py @@ -0,0 +1,16 @@ +from typing import Literal + +from pydantic import BaseModel + + +# Pydantic models for structured outputs +class RewrittenQueries(BaseModel): + rewritten_queries: list[str] + + +class BinaryDecision(BaseModel): + decision: Literal["yes", "no"] + + +class SubQuestions(BaseModel): + sub_questions: list[str] diff --git a/backend/danswer/agent_search/shared_graph_utils/prompts.py b/backend/danswer/agent_search/shared_graph_utils/prompts.py index 9bc1789cd4b..e52234d40d2 100644 --- a/backend/danswer/agent_search/shared_graph_utils/prompts.py +++ b/backend/danswer/agent_search/shared_graph_utils/prompts.py @@ -8,57 +8,6 @@ Formulate the query: """ -INITIAL_DECOMPOSITION_PROMPT = """ \n - Please decompose an initial user question into not more than 4 appropriate sub-questions that help to - answer the original question. The purpose for this decomposition is to isolate individulal entities - (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales - for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our - sales with company A' + 'what is our market share with company A' + 'is company A a reference customer - for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n - - For each sub-question, please also create one search term that can be used to retrieve relevant - documents from a document store. - - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - - Please formulate your answer as a list of json objects with the following format: - - [{{"sub_question": , "search_term": }}, ...] - - Answer: - """ - -INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n - Please decompose an initial user question into not more than 4 appropriate sub-questions that help to - answer the original question. The purpose for this decomposition is to isolate individulal entities - (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales - for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our - sales with company A' + 'what is our market share with company A' + 'is company A a reference customer - for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. 
\n - - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - - Please formulate your answer as a list of subquestions: - - Answer: - """ - -REWRITE_PROMPT_SINGLE = """ \n - Please convert an initial user question into a more appropriate search query for retrievel from a - document store. \n - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - - Formulate the query: """ - BASE_RAG_PROMPT = """ \n You are an assistant for question-answering tasks. Use the context provided below - and only the provided context - to answer the question. If you don't know the answer or if the provided context is @@ -73,47 +22,6 @@ \n\n Answer:""" -INITIAL_RAG_PROMPT = """ \n - You are an assistant for question-answering tasks. Use the information provided below - and only the - provided information - to answer the provided question. - - The information provided below consists of: - 1) a number of answered sub-questions - these are very important(!) and definitely should be - considered to answer the question. - 2) a number of documents that were also deemed relevant for the question. - - If you don't know the answer or if the provided information is empty or insufficient, just say - "I don't know". Do not use your internal knowledge! - - Again, only use the provided informationand do not use your internal knowledge! It is a matter of life - and death that you do NOT use your internal knowledge, just the provided information! - - Try to keep your answer concise. - - And here is the question and the provided information: - \n - \nQuestion:\n {question} - - \nAnswered Sub-questions:\n {answered_sub_questions} - - \nContext:\n {context} \n\n - \n\n - - Answer:""" - -MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below - - and only this context - to answer the question. If you don't know the answer, just say "I don't know". - Use three sentences maximum and keep the answer concise. - Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer. - Again, only use the provided context and do not use your internal knowledge! If you cannot answer the - question based on the context, say "I don't know". It is a matter of life and death that you do NOT - use your internal knowledge, just the provided information! - - \nQuestion: {question} - \nContext: {combined_context} \n - - Answer:""" - BASE_CHECK_PROMPT = """ \n Please check whether 1) the suggested answer seems to fully address the original question AND 2)the original question requests a simple, factual answer, and there are no ambiguities, judgements, @@ -131,20 +39,6 @@ \n ------- \n Please answer with yes or no:""" -SUB_CHECK_PROMPT = """ \n - Please check whether the suggested answer seems to address the original question. - - Please only answer with 'yes' or 'no' \n - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - Here is the proposed answer: - \n ------- \n - {base_answer} - \n ------- \n - Please answer with yes or no:""" - VERIFIER_PROMPT = """ \n Please check whether the document seems to be relevant for the answer of the original question. Please only answer with 'yes' or 'no' \n @@ -158,43 +52,48 @@ \n ------- \n Please answer with yes or no:""" -ENTITY_TERM_PROMPT = """ \n - Based on the original question and the context retieved from a dataset, please generate a list of - entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts - (e.g. 
sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
+INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
+    Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
+    answer the original question. The purpose for this decomposition is to isolate individual entities
+    (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+    for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+    sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+    for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
-
-    \n\n
-    Here is the original question:
+    Here is the initial question:
     \n ------- \n
     {question}
     \n ------- \n
-    And here is the context retrieved:
+
+    Please formulate your answer as a list of subquestions:
+
+    Answer:
+    """
+
+REWRITE_PROMPT_SINGLE = """ \n
+    Please convert an initial user question into a more appropriate search query for retrieval from a
+    document store. \n
+    Here is the initial question:
     \n ------- \n
-    {context}
+    {question}
     \n ------- \n
-
+    Formulate the query: """
+
-    Please format your answer as a json object in the following format:
+MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
+    - and only this context - to answer the question. If you don't know the answer, just say "I don't know".
+    Use three sentences maximum and keep the answer concise.
+    Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
+    Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
+    question based on the context, say "I don't know". It is a matter of life and death that you do NOT
+    use your internal knowledge, just the provided information!
+
+    \nQuestion: {question}
+    \nContext: {combined_context} \n
+
+    Answer:"""
 
-    {{"retrieved_entities_relationships": {{
-        "entities": [{{
-            "entity_name": ,
-            "entity_type": 
-        }}],
-        "relationships": [{{
-            "name": ,
-            "type": ,
-            "entities": [, ]
-        }}],
-        "terms": [{{
-            "term_name": ,
-            "term_type": ,
-            "similar_to": 
-        }}]
-    }}
-    }}
-    """
-
-ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
+_ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
     An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
     good enough. Also, some sub-questions had been answered and this information has been used to provide
     the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
@@ -271,7 +170,7 @@
                  "search_term": }}, ...]}}
     """
 
-DEEP_DECOMPOSE_PROMPT = """ \n
+_DEEP_DECOMPOSE_PROMPT = """ \n
     An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
     good enough. Also, some sub-questions had been answered and this information has been used to provide
     the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
@@ -343,7 +242,7 @@
                  "search_term": }}, ...]}}
     """
 
-DECOMPOSE_PROMPT = """ \n
+_DECOMPOSE_PROMPT = """ \n
     For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
     \n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the @@ -391,7 +290,7 @@ ...]}} """ #### Consolidations -COMBINED_CONTEXT = """------- +_COMBINED_CONTEXT = """------- Below you will find useful information to answer the original question. First, you see a number of sub-questions with their answers. This information should be considered to be more focussed and somewhat more specific to the original question as it tries to contextualized facts. @@ -404,7 +303,7 @@ ---------------- """ -SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """------- +_SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """------- Below you will find a question that we ultimately want to answer (the original question) and a list of motivations in arbitrary order for generated sub-questions that are supposed to help us answering the original question. The motivations are formatted as : . diff --git a/backend/danswer/agent_search/subgraph/edges.py b/backend/danswer/agent_search/subgraph/edges.py deleted file mode 100644 index 2d293a5a34d..00000000000 --- a/backend/danswer/agent_search/subgraph/edges.py +++ /dev/null @@ -1,25 +0,0 @@ -from langgraph.types import Send - -from danswer.agent_search.primary_graph.states import QAState - - -def sub_continue_to_verifier(qa_state: QAState) -> list[Send]: - # Routes each de-douped retrieved doc to the verifier step - in parallel - # Notice the 'Send()' API that takes care of the parallelization - - return [ - Send( - "verifier", - {"document": doc, "question": qa_state["sub_question"]}, - ) - for doc in qa_state["sub_question_base_retrieval_docs"] - ] - - -def sub_continue_to_retrieval(qa_state: QAState) -> list[Send]: - # Routes re-written queries to the (parallel) retrieval steps - # Notice the 'Send()' API that takes care of the parallelization - return [ - Send("sub_custom_retrieve", {"rewritten_query": query}) - for query in qa_state["sub_question_rewritten_queries"] - ] diff --git a/backend/danswer/agent_search/subgraph/graph_builder.py b/backend/danswer/agent_search/subgraph/graph_builder.py deleted file mode 100644 index 94a5e976dbe..00000000000 --- a/backend/danswer/agent_search/subgraph/graph_builder.py +++ /dev/null @@ -1,59 +0,0 @@ -from langgraph.graph import END -from langgraph.graph import START -from langgraph.graph import StateGraph - -from danswer.agent_search.subgraph.edges import sub_continue_to_retrieval -from danswer.agent_search.subgraph.edges import sub_continue_to_verifier -from danswer.agent_search.subgraph.nodes import sub_combine_retrieved_docs -from danswer.agent_search.subgraph.nodes import sub_custom_retrieve -from danswer.agent_search.subgraph.nodes import sub_final_format -from danswer.agent_search.subgraph.nodes import sub_generate -from danswer.agent_search.subgraph.nodes import sub_qa_check -from danswer.agent_search.subgraph.nodes import sub_rewrite -from danswer.agent_search.subgraph.nodes import verifier -from danswer.agent_search.subgraph.states import SubQAOutputState -from danswer.agent_search.subgraph.states import SubQAState - - -def build_subgraph() -> StateGraph: - sub_answers = StateGraph(SubQAState, output=SubQAOutputState) - sub_answers.add_node(node="sub_rewrite", action=sub_rewrite) - sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve) - sub_answers.add_node( - node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs - ) - sub_answers.add_node(node="verifier", action=verifier) - sub_answers.add_node(node="sub_generate", action=sub_generate) - 
sub_answers.add_node(node="sub_qa_check", action=sub_qa_check) - sub_answers.add_node(node="sub_final_format", action=sub_final_format) - - sub_answers.add_edge(START, "sub_rewrite") - - sub_answers.add_conditional_edges( - "sub_rewrite", sub_continue_to_retrieval, ["sub_custom_retrieve"] - ) - - sub_answers.add_edge("sub_custom_retrieve", "sub_combine_retrieved_docs") - - sub_answers.add_conditional_edges( - "sub_combine_retrieved_docs", sub_continue_to_verifier, ["verifier"] - ) - - sub_answers.add_edge("verifier", "sub_generate") - - sub_answers.add_edge("sub_generate", "sub_qa_check") - - sub_answers.add_edge("sub_qa_check", "sub_final_format") - - sub_answers.add_edge("sub_final_format", END) - sub_answers_graph = sub_answers.compile() - return sub_answers_graph - - -if __name__ == "__main__": - # TODO: add the actual question - inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"} - sub_answers_graph = build_subgraph() - output = sub_answers_graph.invoke(inputs) - print("\nOUTPUT:") - print(output) diff --git a/backend/danswer/agent_search/subgraph/nodes.py b/backend/danswer/agent_search/subgraph/nodes.py deleted file mode 100644 index 27cf0243859..00000000000 --- a/backend/danswer/agent_search/subgraph/nodes.py +++ /dev/null @@ -1,316 +0,0 @@ -import datetime -import json -from typing import Any -from typing import Dict -from typing import Literal - -from langchain_core.messages import HumanMessage -from pydantic import BaseModel - -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.subgraph.states import SubQAOutputState -from danswer.agent_search.subgraph.states import SubQAState -from danswer.chat.models import DanswerContext -from danswer.llm.interfaces import LLM - - -class BinaryDecision(BaseModel): - decision: Literal["yes", "no"] - - -# unused at this point. Kept from tutorial. But here the agent makes a routing decision -# not that much of an agent if the routing is static... -def sub_rewrite(sub_qa_state: SubQAState) -> Dict[str, Any]: - """ - Transform the initial question into more suitable search queries. 
- - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - print("---SUB TRANSFORM QUERY---") - - start_time = datetime.datetime.now() - - # messages = state["base_answer_messages"] - question = sub_qa_state["sub_question"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), - ) - ] - print(msg) - - # Get the rewritten queries in a defined format - ##response = model.with_structured_output(RewrittenQuery).invoke(msg) - ##rewritten_query = response.base_answer_rewritten_query - - rewritten_queries = ["music hard to listen to", "Music that is not fun or pleasant"] - - end_time = datetime.datetime.now() - return { - "sub_question_rewritten_queries": rewritten_queries, - "log_messages": f"{str(start_time)} - {str(end_time)}: sub - rewrite", - } - - -# dummy node to report on state if needed -def sub_custom_retrieve(retriever_state: RetrieverState) -> Dict[str, Any]: - """ - Retrieve documents - - Args: - state (dict): The current graph state - - Returns: - state (dict): New key added to state, documents, that contains retrieved documents - """ - print("---RETRIEVE SUB---") - - start_time = datetime.datetime.now() - - retriever_state["rewritten_query"] - - # query = state["rewritten_query"] - - # Retrieval - # TODO: add the actual retrieval, probably from search_tool.run() - documents: list[DanswerContext] = [] - - end_time = datetime.datetime.now() - return { - "sub_question_base_retrieval_docs": documents, - "log_messages": f"{str(start_time)} - {str(end_time)}: sub - custom_retrieve", - } - - -def sub_combine_retrieved_docs(sub_qa_state: SubQAState) -> Dict[str, Any]: - """ - Dedupe the retrieved docs. - """ - start_time = datetime.datetime.now() - - sub_question_base_retrieval_docs = sub_qa_state["sub_question_base_retrieval_docs"] - - print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}") - dedupe_docs = [] - for base_retrieval_doc in sub_question_base_retrieval_docs: - if base_retrieval_doc not in dedupe_docs: - dedupe_docs.append(base_retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - end_time = datetime.datetime.now() - return { - "sub_question_deduped_retrieval_docs": dedupe_docs, - "log_messages": f"{str(start_time)} - {str(end_time)}: base - combine_retrieved_docs (dedupe)", - } - - -def verifier(verifier_state: VerifierState) -> Dict[str, Any]: - """ - Check whether the document is relevant for the original user question - - Args: - state (VerifierState): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---VERIFY QUTPUT---") - start_time = datetime.datetime.now() - - question = verifier_state["original_question"] - document_content = verifier_state["document"].content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - # Grader - llm: LLM = verifier_state["llm"] - tools: list[dict] = verifier_state["tools"] - response = list( - llm.stream( - prompt=msg, - tools=tools, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - formatted_response: BinaryDecision = json.loads(response[0].content) - verdict = formatted_response.decision - - print(f"Verdict: {verdict}") - - end_time = datetime.datetime.now() - if verdict == "yes": - end_time = datetime.datetime.now() - return { - "sub_question_verified_retrieval_docs": [verifier_state["document"]], - "log_messages": f"{str(start_time)} - 
{str(end_time)}: base - verifier: yes", - } - else: - end_time = datetime.datetime.now() - return { - "sub_question_verified_retrieval_docs": [], - "log_messages": f"{str(start_time)} - {str(end_time)}: base - verifier: no", - } - - -def sub_generate(sub_qa_state: SubQAState) -> Dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---GENERATE---") - start_time = datetime.datetime.now() - - question = sub_qa_state["sub_question"] - docs = sub_qa_state["sub_question_verified_retrieval_docs"] - - print(f"Number of verified retrieval docs: {docs}") - - # LLM - llm: LLM = sub_qa_state["llm"] - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) - ) - ] - - # Grader - llm: LLM = sub_qa_state["llm"] - tools: list[dict] = sub_qa_state["tools"] - response = list( - llm.stream( - prompt=msg, - tools=tools, - structured_response_format=None, - ) - ) - - answer = response[0].content - - end_time = datetime.datetime.now() - return { - "sub_question_answer": answer, - "log_messages": f"{str(start_time)} - {str(end_time)}: base - generate", - } - - -def sub_base_check(sub_qa_state: SubQAState) -> Dict[str, Any]: - """ - Check whether the final output satisfies the original user question - - Args: - state (messages): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---CHECK QUTPUT---") - start_time = datetime.datetime.now() - - base_answer = sub_qa_state["core_answer_base_answer"] - - question = sub_qa_state["original_question"] - - BASE_CHECK_MESSAGE = [ - HumanMessage( - content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer) - ) - ] - - llm: LLM = sub_qa_state["llm"] - tools: list[dict] = sub_qa_state["tools"] - response = list( - llm.stream( - prompt=BASE_CHECK_MESSAGE, - tools=tools, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - formatted_response: BinaryDecision = json.loads(response[0].content) - verdict = formatted_response.decision - - print(f"Verdict: {verdict}") - - end_time = datetime.datetime.now() - return { - "base_answer": base_answer, - "log_messages": f"{str(start_time)} - {str(end_time)}: core - base_check", - } - - -def sub_final_format(sub_qa_state: SubQAState) -> SubQAOutputState: - """ - Create the final output for the QA subgraph - """ - - print("---BASE FINAL FORMAT---") - datetime.datetime.now() - - return { - "sub_qas": [ - { - "sub_question": sub_qa_state["sub_question"], - "sub_answer": sub_qa_state["sub_question_answer"], - "sub_answer_check": sub_qa_state["sub_question_answer_check"], - } - ], - "log_messages": sub_qa_state["log_messages"], - } - - -def sub_qa_check(sub_qa_state: SubQAState) -> Dict[str, str]: - """ - Check if the sub-question answer is satisfactory. 
- - Args: - state: The current SubQAState containing the sub-question and its answer - - Returns: - Dict containing the check result and log message - """ - end_time = datetime.datetime.now() - - q = sub_qa_state["sub_question"] - a = sub_qa_state["sub_question_answer"] - - BASE_CHECK_MESSAGE = [ - HumanMessage(content=BASE_CHECK_PROMPT.format(question=q, base_answer=a)) - ] - - model: LLM = sub_qa_state["llm"] - response = model.invoke(BASE_CHECK_MESSAGE) - - start_time = datetime.datetime.now() - - return { - "sub_question_answer_check": response.content.lower(), - "base_answer_messages": f"{str(start_time)} - {str(end_time)}: base - qa_check", - } From 56052c5b4b947c05de7a078ad63d72b583606c0b Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Sat, 7 Dec 2024 06:09:57 -0800 Subject: [PATCH 04/10] imports --- backend/danswer/agent_search/primary_graph/nodes.py | 6 +++--- backend/danswer/agent_search/research_qa_sub_graph/nodes.py | 2 +- .../{base_qa_sub_graph => research_qa_sub_graph}/prompts.py | 0 3 files changed, 4 insertions(+), 4 deletions(-) rename backend/danswer/agent_search/{base_qa_sub_graph => research_qa_sub_graph}/prompts.py (100%) diff --git a/backend/danswer/agent_search/primary_graph/nodes.py b/backend/danswer/agent_search/primary_graph/nodes.py index 45fc14358b0..5252e60069b 100644 --- a/backend/danswer/agent_search/primary_graph/nodes.py +++ b/backend/danswer/agent_search/primary_graph/nodes.py @@ -6,15 +6,15 @@ from langchain_core.messages import HumanMessage +from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT +from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT +from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT from danswer.agent_search.primary_graph.states import QAState from danswer.agent_search.primary_graph.states import RetrieverState from danswer.agent_search.primary_graph.states import VerifierState from danswer.agent_search.shared_graph_utils.models import BinaryDecision from danswer.agent_search.shared_graph_utils.models import RewrittenQueries from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string diff --git a/backend/danswer/agent_search/research_qa_sub_graph/nodes.py b/backend/danswer/agent_search/research_qa_sub_graph/nodes.py index 85401a5c8fa..2c982cd7143 100644 --- a/backend/danswer/agent_search/research_qa_sub_graph/nodes.py +++ b/backend/danswer/agent_search/research_qa_sub_graph/nodes.py @@ -7,12 +7,12 @@ from danswer.agent_search.base_qa_sub_graph.states import BaseQAState from danswer.agent_search.primary_graph.states import RetrieverState from danswer.agent_search.primary_graph.states import VerifierState +from danswer.agent_search.research_qa_sub_graph.prompts import SUB_CHECK_PROMPT from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState from danswer.agent_search.shared_graph_utils.models import BinaryDecision from danswer.agent_search.shared_graph_utils.models import RewrittenQueries from danswer.agent_search.shared_graph_utils.prompts 
import BASE_RAG_PROMPT from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT from danswer.agent_search.shared_graph_utils.utils import format_docs from danswer.agent_search.shared_graph_utils.utils import generate_log_message diff --git a/backend/danswer/agent_search/base_qa_sub_graph/prompts.py b/backend/danswer/agent_search/research_qa_sub_graph/prompts.py similarity index 100% rename from backend/danswer/agent_search/base_qa_sub_graph/prompts.py rename to backend/danswer/agent_search/research_qa_sub_graph/prompts.py From 091cb136c41e35b96499b621eb487c29723e8e34 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Sat, 7 Dec 2024 12:25:54 -0800 Subject: [PATCH 05/10] got core qa graph working --- .../agent_search/base_qa_sub_graph/nodes.py | 306 ----------- .../edges.py | 9 +- .../graph_builder.py | 59 ++- .../core_qa_graph/nodes/__init__.py | 0 .../nodes/combine_retrieved_docs.py | 35 ++ .../core_qa_graph/nodes/custom_retrieve.py | 51 ++ .../agent_search/core_qa_graph/nodes/dummy.py | 24 + .../core_qa_graph/nodes/final_format.py | 22 + .../core_qa_graph/nodes/generate.py | 55 ++ .../core_qa_graph/nodes/qa_check.py | 51 ++ .../core_qa_graph/nodes/rewrite.py | 62 +++ .../core_qa_graph/nodes/verifier.py | 64 +++ .../states.py | 32 +- .../edges.py | 2 +- .../graph_builder.py | 40 +- .../deep_qa_graph/nodes/__init__.py | 0 .../nodes/combine_retrieved_docs.py | 31 ++ .../deep_qa_graph/nodes/custom_retrieve.py | 33 ++ .../agent_search/deep_qa_graph/nodes/dummy.py | 21 + .../deep_qa_graph/nodes/final_format.py | 31 ++ .../deep_qa_graph/nodes/generate.py | 56 ++ .../deep_qa_graph/nodes/qa_check.py | 57 ++ .../deep_qa_graph/nodes/rewrite.py | 64 +++ .../deep_qa_graph/nodes/verifier.py | 59 +++ .../prompts.py | 0 .../states.py | 22 +- .../agent_search/primary_graph/edges.py | 4 +- .../primary_graph/graph_builder.py | 64 ++- .../agent_search/primary_graph/nodes.py | 492 ------------------ .../primary_graph/nodes/__init__.py | 0 .../primary_graph/nodes/base_wait.py | 27 + .../nodes/combine_retrieved_docs.py | 36 ++ .../primary_graph/nodes/custom_retrieve.py | 52 ++ .../primary_graph/nodes/decompose.py | 78 +++ .../nodes/deep_answer_generation.py | 61 +++ .../primary_graph/nodes/dummy_start.py | 11 + .../nodes/entity_term_extraction.py | 51 ++ .../primary_graph/nodes/final_stuff.py | 85 +++ .../primary_graph/nodes/generate.py | 52 ++ .../primary_graph/nodes/generate_initial.py | 72 +++ .../primary_graph/nodes/main_decomp_base.py | 64 +++ .../primary_graph/nodes/rewrite.py | 55 ++ .../nodes/sub_qa_level_aggregator.py | 39 ++ .../primary_graph/nodes/sub_qa_manager.py | 28 + .../primary_graph/nodes/verifier.py | 59 +++ .../agent_search/primary_graph/states.py | 26 +- .../research_qa_sub_graph/nodes.py | 308 ----------- .../shared_graph_utils/prompts.py | 10 +- .../agent_search/shared_graph_utils/utils.py | 6 +- backend/danswer/tools/message.py | 3 + backend/requirements/default.txt | 2 - 51 files changed, 1661 insertions(+), 1210 deletions(-) delete mode 100644 backend/danswer/agent_search/base_qa_sub_graph/nodes.py rename backend/danswer/agent_search/{base_qa_sub_graph => core_qa_graph}/edges.py (79%) rename backend/danswer/agent_search/{base_qa_sub_graph => core_qa_graph}/graph_builder.py (53%) create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/__init__.py create mode 100644 
backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/dummy.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/final_format.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/generate.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py create mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/verifier.py rename backend/danswer/agent_search/{base_qa_sub_graph => core_qa_graph}/states.py (73%) rename backend/danswer/agent_search/{research_qa_sub_graph => deep_qa_graph}/edges.py (95%) rename backend/danswer/agent_search/{research_qa_sub_graph => deep_qa_graph}/graph_builder.py (61%) create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/__init__.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/generate.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py create mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py rename backend/danswer/agent_search/{research_qa_sub_graph => deep_qa_graph}/prompts.py (100%) rename backend/danswer/agent_search/{research_qa_sub_graph => deep_qa_graph}/states.py (74%) delete mode 100644 backend/danswer/agent_search/primary_graph/nodes.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/__init__.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/base_wait.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/decompose.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/dummy_start.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/final_stuff.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/generate.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/generate_initial.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/rewrite.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py create mode 100644 backend/danswer/agent_search/primary_graph/nodes/verifier.py delete mode 100644 backend/danswer/agent_search/research_qa_sub_graph/nodes.py diff --git a/backend/danswer/agent_search/base_qa_sub_graph/nodes.py b/backend/danswer/agent_search/base_qa_sub_graph/nodes.py deleted file mode 100644 index 
e616ef906eb..00000000000 --- a/backend/danswer/agent_search/base_qa_sub_graph/nodes.py +++ /dev/null @@ -1,306 +0,0 @@ -import datetime -import json -from typing import Any - -from langchain_core.messages import HumanMessage - -from danswer.agent_search.base_qa_sub_graph.states import BaseQAState -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.shared_graph_utils.models import BinaryDecision -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.chat.models import DanswerContext -from danswer.llm.interfaces import LLM - - -# unused at this point. Kept from tutorial. But here the agent makes a routing decision -# not that much of an agent if the routing is static... -def sub_rewrite(state: BaseQAState) -> dict[str, Any]: - """ - Transform the initial question into more suitable search queries. - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - print("---SUB TRANSFORM QUERY---") - - node_start_time = datetime.datetime.now() - - # messages = state["base_answer_messages"] - question = state["sub_question_str"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), - ) - ] - fast_llm: LLM = state["fast_llm"] - llm_response = list( - fast_llm.stream( - prompt=msg, - structured_response_format=RewrittenQueries.model_json_schema(), - ) - ) - - # Get the rewritten queries in a defined format - rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr()) - - print(f"rewritten_queries: {rewritten_queries}") - - rewritten_queries = RewrittenQueries( - rewritten_queries=[ - "music hard to listen to", - "Music that is not fun or pleasant", - ] - ) - - print(f"hardcoded rewritten_queries: {rewritten_queries}") - - return { - "sub_question_rewritten_queries": rewritten_queries, - "log_messages": generate_log_message( - message="sub - rewrite", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -# dummy node to report on state if needed -def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: - """ - Retrieve documents - - Args: - state (dict): The current graph state - - Returns: - state (dict): New key added to state, documents, that contains retrieved documents - """ - print("---RETRIEVE SUB---") - - node_start_time = datetime.datetime.now() - - # rewritten_query = state["rewritten_query"] - - # Retrieval - # TODO: add the actual retrieval, probably from search_tool.run() - documents: list[DanswerContext] = [] - - return { - "sub_question_base_retrieval_docs": documents, - "log_messages": generate_log_message( - message="sub - custom_retrieve", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]: - """ - Dedupe the retrieved docs. 
- """ - node_start_time = datetime.datetime.now() - - sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"] - - print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}") - dedupe_docs = [] - for base_retrieval_doc in sub_question_base_retrieval_docs: - if base_retrieval_doc not in dedupe_docs: - dedupe_docs.append(base_retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return { - "sub_question_deduped_retrieval_docs": dedupe_docs, - "log_messages": generate_log_message( - message="sub - combine_retrieved_docs (dedupe)", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_verifier(state: VerifierState) -> dict[str, Any]: - """ - Check whether the document is relevant for the original user question - - Args: - state (VerifierState): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---VERIFY QUTPUT---") - node_start_time = datetime.datetime.now() - - question = state["question"] - document_content = state["document"].content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - # Grader - llm: LLM = state["primary_llm"] - response = list( - llm.stream( - prompt=msg, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - raw_response = json.loads(response[0].pretty_repr()) - formatted_response = BinaryDecision.model_validate(raw_response) - - print(f"Verdict: {formatted_response.decision}") - - return { - "sub_question_verified_retrieval_docs": [state["document"]] - if formatted_response.decision == "yes" - else [], - "log_messages": generate_log_message( - message=f"sub - verifier: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_generate(state: BaseQAState) -> dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---GENERATE---") - start_time = datetime.datetime.now() - - question = state["original_question"] - docs = state["sub_question_verified_retrieval_docs"] - - print(f"Number of verified retrieval docs: {len(docs)}") - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) - ) - ] - - # Grader - llm: LLM = state["primary_llm"] - response = list( - llm.stream( - prompt=msg, - structured_response_format=None, - ) - ) - - answer = response[0].pretty_repr() - return { - "sub_question_answer": answer, - "log_messages": generate_log_message( - message="base - generate", - node_start_time=start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_final_format(state: BaseQAState) -> dict[str, Any]: - """ - Create the final output for the QA subgraph - """ - - print("---BASE FINAL FORMAT---") - datetime.datetime.now() - - return { - "sub_qas": [ - { - "sub_question": state["sub_question_str"], - "sub_answer": state["sub_question_answer"], - "sub_answer_check": state["sub_question_answer_check"], - } - ], - "log_messages": state["log_messages"], - } - - -def sub_qa_check(state: BaseQAState) -> dict[str, Any]: - """ - Check if the sub-question answer is satisfactory. 
- - Args: - state: The current SubQAState containing the sub-question and its answer - - Returns: - dict containing the check result and log message - """ - - msg = [ - HumanMessage( - content=BASE_CHECK_PROMPT.format( - question=state["sub_question_str"], - base_answer=state["sub_question_answer"], - ) - ) - ] - - model: LLM = state["primary_llm"] - response = list( - model.stream( - prompt=msg, - structured_response_format=None, - ) - ) - - start_time = datetime.datetime.now() - - return { - "sub_question_answer_check": response[0].pretty_repr().lower(), - "base_answer_messages": generate_log_message( - message="sub - qa_check", - node_start_time=start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_dummy(state: BaseQAState) -> dict[str, Any]: - """ - Dummy step - """ - - print("---Sub Dummy---") - - node_start_time = datetime.datetime.now() - - return { - "log_messages": generate_log_message( - message="sub - dummy", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/base_qa_sub_graph/edges.py b/backend/danswer/agent_search/core_qa_graph/edges.py similarity index 79% rename from backend/danswer/agent_search/base_qa_sub_graph/edges.py rename to backend/danswer/agent_search/core_qa_graph/edges.py index 08118446f38..0d0c2b3a50d 100644 --- a/backend/danswer/agent_search/base_qa_sub_graph/edges.py +++ b/backend/danswer/agent_search/core_qa_graph/edges.py @@ -3,7 +3,7 @@ from langgraph.types import Send -from danswer.agent_search.base_qa_sub_graph.states import BaseQAState +from danswer.agent_search.core_qa_graph.states import BaseQAState from danswer.agent_search.primary_graph.states import RetrieverState from danswer.agent_search.primary_graph.states import VerifierState @@ -18,8 +18,6 @@ def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashabl VerifierState( document=doc, question=state["sub_question_str"], - fast_llm=state["fast_llm"], - primary_llm=state["primary_llm"], graph_start_time=state["graph_start_time"], ), ) @@ -30,15 +28,14 @@ def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashabl def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]: # Routes re-written queries to the (parallel) retrieval steps # Notice the 'Send()' API that takes care of the parallelization + rewritten_queries = state["sub_question_search_queries"].rewritten_queries return [ Send( "sub_custom_retrieve", RetrieverState( rewritten_query=query, - primary_llm=state["primary_llm"], - fast_llm=state["fast_llm"], graph_start_time=state["graph_start_time"], ), ) - for query in state["sub_question_search_queries"] + for query in rewritten_queries ] diff --git a/backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py b/backend/danswer/agent_search/core_qa_graph/graph_builder.py similarity index 53% rename from backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py rename to backend/danswer/agent_search/core_qa_graph/graph_builder.py index 62943930c0f..1031d945cc5 100644 --- a/backend/danswer/agent_search/base_qa_sub_graph/graph_builder.py +++ b/backend/danswer/agent_search/core_qa_graph/graph_builder.py @@ -2,30 +2,36 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from danswer.agent_search.base_qa_sub_graph.edges import sub_continue_to_retrieval -from danswer.agent_search.base_qa_sub_graph.edges import sub_continue_to_verifier -from danswer.agent_search.base_qa_sub_graph.nodes import 
sub_combine_retrieved_docs
-from danswer.agent_search.base_qa_sub_graph.nodes import sub_custom_retrieve
-from danswer.agent_search.base_qa_sub_graph.nodes import sub_dummy
-from danswer.agent_search.base_qa_sub_graph.nodes import sub_final_format
-from danswer.agent_search.base_qa_sub_graph.nodes import sub_generate
-from danswer.agent_search.base_qa_sub_graph.nodes import sub_qa_check
-from danswer.agent_search.base_qa_sub_graph.nodes import sub_verifier
-from danswer.agent_search.base_qa_sub_graph.states import BaseQAOutputState
-from danswer.agent_search.base_qa_sub_graph.states import BaseQAState
+from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
+from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
+from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
+    sub_combine_retrieved_docs,
+)
+from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
+    sub_custom_retrieve,
+)
+from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
+from danswer.agent_search.core_qa_graph.nodes.final_format import (
+    sub_final_format,
+)
+from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
+from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
+from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
+from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
+from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.core_qa_graph.states import CoreQAInputState
 
-# from danswer.agent_search.base_qa_sub_graph.nodes import sub_rewrite
-
-def build_base_qa_sub_graph() -> StateGraph:
+def build_core_qa_graph() -> StateGraph:
     sub_answers_initial = StateGraph(
         state_schema=BaseQAState,
         output=BaseQAOutputState,
     )
 
     ### Add nodes ###
-    # sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
     sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
+    sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
     sub_answers_initial.add_node(
         node="sub_custom_retrieve",
         action=sub_custom_retrieve,
@@ -52,12 +58,12 @@ def build_base_qa_sub_graph() -> StateGraph:
     )
 
     ### Add edges ###
-    sub_answers_initial.add_edge(START, "sub_rewrite")
+    sub_answers_initial.add_edge(START, "sub_dummy")
+    sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")
 
     sub_answers_initial.add_conditional_edges(
         source="sub_rewrite",
         path=sub_continue_to_retrieval,
-        path_map=["sub_custom_retrieve"],
     )
 
     sub_answers_initial.add_edge(
@@ -95,10 +101,20 @@ def build_base_qa_sub_graph() -> StateGraph:
 
 
 if __name__ == "__main__":
-    # TODO: add the actual question
-    inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
-    sub_answers_graph = build_base_qa_sub_graph()
+    inputs = CoreQAInputState(
+        original_question="Whose music is kind of hard to easily enjoy?",
+        sub_question_str="Whose music is kind of hard to easily enjoy?",
+    )
+    sub_answers_graph = build_core_qa_graph()
     compiled_sub_answers = sub_answers_graph.compile()
     output = compiled_sub_answers.invoke(inputs)
     print("\nOUTPUT:")
-    print(output)
+    print(output.keys())
+    for key, value in output.items():
+        if key in [
+            "sub_question_answer",
+            "sub_question_str",
+            "sub_qas",
+            "initial_sub_qas",
+        ]:
+            print(f"{key}: {value}")
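The __main__ block above is the quickest way to smoke-test any of these graphs: build, compile, invoke with an input state, and the compiled graph hands back the final state dict. A minimal sketch of that same lifecycle with a toy one-node graph, using only stock langgraph APIs (the state and node here are invented for illustration, not part of the patch):

from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class ToyState(TypedDict):
    question: str
    answer: str


def answer_node(state: ToyState) -> dict:
    # Nodes return partial state updates; LangGraph merges them into the state.
    return {"answer": f"echo: {state['question']}"}


toy = StateGraph(ToyState)
toy.add_node(node="answer", action=answer_node)
toy.add_edge(START, "answer")
toy.add_edge("answer", END)
compiled = toy.compile()
print(compiled.invoke({"question": "ping"}))  # {'question': 'ping', 'answer': 'echo: ping'}

diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/__init__.py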
b/backend/danswer/agent_search/core_qa_graph/nodes/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py
new file mode 100644
index 00000000000..64f657363ef
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py
@@ -0,0 +1,35 @@
+from datetime import datetime
+from typing import Any
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.context.search.models import InferenceSection
+
+
+def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
+    """
+    Dedupe the retrieved docs.
+    """
+    node_start_time = datetime.now()
+
+    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
+
+    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
+    dedupe_docs: list[InferenceSection] = []
+    for base_retrieval_doc in sub_question_base_retrieval_docs:
+        if not any(
+            base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
+            for doc in dedupe_docs
+        ):
+            dedupe_docs.append(base_retrieval_doc)
+
+    print(f"Number of deduped docs: {len(dedupe_docs)}")
+
+    return {
+        "sub_question_deduped_retrieval_docs": dedupe_docs,
+        "log_messages": generate_log_message(
+            message="sub - combine_retrieved_docs (dedupe)",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py
new file mode 100644
index 00000000000..4793c511b7c
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py
@@ -0,0 +1,51 @@
+import datetime
+from typing import Any
+
+from danswer.agent_search.primary_graph.states import RetrieverState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.context.search.models import InferenceSection
+from danswer.context.search.models import SearchRequest
+from danswer.context.search.pipeline import SearchPipeline
+from danswer.db.engine import get_session_context_manager
+from danswer.llm.factory import get_default_llms
+
+
+def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
+    """
+    Retrieve documents
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): New key added to state, documents, that contains retrieved documents
+    """
+    print("---RETRIEVE SUB---")
+
+    node_start_time = datetime.datetime.now()
+
+    rewritten_query = state["rewritten_query"]
+
+    # Retrieval: run the search pipeline and use its reranked sections as
+    # the retrieved documents for this rewritten query
+    documents: list[InferenceSection] = []
+    llm, fast_llm = get_default_llms()
+    with get_session_context_manager() as db_session:
+        documents = SearchPipeline(
+            search_request=SearchRequest(
+                query=rewritten_query,
+            ),
+            user=None,
+            llm=llm,
+            fast_llm=fast_llm,
+            db_session=db_session,
+        ).reranked_sections
+
+    return {
+        "sub_question_base_retrieval_docs": documents,
+        "log_messages": generate_log_message(
+            message="sub - custom_retrieve",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
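A side note on sub_combine_retrieved_docs above: the any() scan over dedupe_docs is quadratic in the number of retrieved sections. If the retrieval fan-out ever grows, deduping through a dict keyed on the document id is linear and keeps first-seen order. A hedged sketch (dedupe_sections is a hypothetical helper, assuming only that sections expose center_chunk.document_id as in the diff):

def dedupe_sections(sections: list) -> list:
    # dict preserves insertion order, so the original ranking survives;
    # setdefault keeps the first section seen for each document id.
    by_doc_id: dict = {}
    for section in sections:
        by_doc_id.setdefault(section.center_chunk.document_id, section)
    return list(by_doc_id.values())

diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py b/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py
new file mode 100644
index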
00000000000..be018334aa9
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py
@@ -0,0 +1,24 @@
+import datetime
+from typing import Any
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_dummy(state: BaseQAState) -> dict[str, Any]:
+    """
+    Dummy step
+    """
+
+    print("---Sub Dummy---")
+
+    node_start_time = datetime.datetime.now()
+
+    return {
+        "graph_start_time": node_start_time,
+        "log_messages": generate_log_message(
+            message="sub - dummy",
+            node_start_time=node_start_time,
+            graph_start_time=node_start_time,
+        ),
+    }
diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py b/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py
new file mode 100644
index 00000000000..8b24cc0a6ed
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py
@@ -0,0 +1,22 @@
+from typing import Any
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+
+
+def sub_final_format(state: BaseQAState) -> dict[str, Any]:
+    """
+    Create the final output for the QA subgraph
+    """
+
+    print("---BASE FINAL FORMAT---")
+
+    return {
+        "sub_qas": [
+            {
+                "sub_question": state["sub_question_str"],
+                "sub_answer": state["sub_question_answer"],
+                "sub_answer_check": state["sub_question_answer_check"],
+            }
+        ],
+        "log_messages": state["log_messages"],
+    }
diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/generate.py b/backend/danswer/agent_search/core_qa_graph/nodes/generate.py
new file mode 100644
index 00000000000..ed1c3661be9
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/generate.py
@@ -0,0 +1,55 @@
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import merge_message_runs
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import format_docs
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.llm.factory import get_default_llms
+
+
+def sub_generate(state: BaseQAState) -> dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the generated answer
+    """
+    print("---GENERATE---")
+    node_start_time = datetime.now()
+
+    question = state["original_question"]
+    docs = state["sub_question_verified_retrieval_docs"]
+
+    print(f"Number of verified retrieval docs: {len(docs)}")
+
+    msg = [
+        HumanMessage(
+            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
+        )
+    ]
+
+    # Grader
+    _, fast_llm = get_default_llms()
+    response = list(
+        fast_llm.stream(
+            prompt=msg,
+            # structured_response_format=None,
+        )
+    )
+
+    answer_str = merge_message_runs(response, chunk_separator="")[0].content
+    return {
+        "sub_question_answer": answer_str,
+        "log_messages": generate_log_message(
+            message="base - generate",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
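These generate and check nodes all share one response-handling idiom: llm.stream() yields a list of message chunks, and merge_message_runs() collapses consecutive chunks into a single message whose .content is the full text. A minimal sketch of just that idiom, using hand-built chunks in place of a live LLM:

from langchain_core.messages import AIMessageChunk, merge_message_runs

chunks = [AIMessageChunk(content="Hello, "), AIMessageChunk(content="world.")]
# chunk_separator="" keeps the streamed pieces glued together verbatim,
# matching how the nodes above reassemble streamed output.
merged = merge_message_runs(chunks, chunk_separator="")[0]
assert merged.content == "Hello, world."

diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py b/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py
new file mode 100644
index 00000000000..adff7a3d4cd
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py
@@ -0,0 +1,51 @@
+import datetime
+from typing import Any
+
+from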
langchain_core.messages import HumanMessage
+from langchain_core.messages import merge_message_runs
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.llm.factory import get_default_llms
+
+
+def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
+    """
+    Check if the sub-question answer is satisfactory.
+
+    Args:
+        state: The current BaseQAState containing the sub-question and its answer
+
+    Returns:
+        dict containing the check result and log message
+    """
+    node_start_time = datetime.datetime.now()
+
+    msg = [
+        HumanMessage(
+            content=BASE_CHECK_PROMPT.format(
+                question=state["sub_question_str"],
+                base_answer=state["sub_question_answer"],
+            )
+        )
+    ]
+
+    _, fast_llm = get_default_llms()
+    response = list(
+        fast_llm.stream(
+            prompt=msg,
+            # structured_response_format=None,
+        )
+    )
+
+    response_str = merge_message_runs(response, chunk_separator="")[0].content
+
+    return {
+        "sub_question_answer_check": response_str,
+        "base_answer_messages": generate_log_message(
+            message="sub - qa_check",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py b/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py
new file mode 100644
index 00000000000..e5841efbd46
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py
@@ -0,0 +1,62 @@
+import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import merge_message_runs
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
+from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.llm.factory import get_default_llms
+
+
+def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
+    """
+    Transform the initial question into more suitable search queries.
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with re-phrased question
+    """
+
+    print("---SUB TRANSFORM QUERY---")
+
+    node_start_time = datetime.datetime.now()
+
+    # messages = state["base_answer_messages"]
+    question = state["sub_question_str"]
+
+    msg = [
+        HumanMessage(
+            content=REWRITE_PROMPT_MULTI.format(question=question),
+        )
+    ]
+    _, fast_llm = get_default_llms()
+    llm_response_list = list(
+        fast_llm.stream(
+            prompt=msg,
+            # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
+            # structured_response_format=RewrittenQueries.model_json_schema(),
+        )
+    )
+    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
+
+    print(f"llm_response: {llm_response}")
+
+    rewritten_queries = llm_response.split("\n")
+
+    print(f"rewritten_queries: {rewritten_queries}")
+
+    rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)
+
+    return {
+        "sub_question_search_queries": rewritten_queries,
+        "log_messages": generate_log_message(
+            message="sub - rewrite",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py b/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py
new file mode 100644
index 00000000000..8bcaee1a5b7
--- /dev/null
+++ b/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py
@@ -0,0 +1,64 @@
+import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import merge_message_runs
+
+from danswer.agent_search.primary_graph.states import VerifierState
+from danswer.agent_search.shared_graph_utils.models import BinaryDecision
+from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.llm.factory import get_default_llms
+
+
+def sub_verifier(state: VerifierState) -> dict[str, Any]:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        state (VerifierState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---VERIFY OUTPUT---")
+    node_start_time = datetime.datetime.now()
+
+    question = state["question"]
+    document_content = state["document"].combined_content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=question, document_content=document_content
+            )
+        )
+    ]
+
+    # Grader
+    llm, fast_llm = get_default_llms()
+    response = list(
+        llm.stream(
+            prompt=msg,
+            # structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    response_string = merge_message_runs(response, chunk_separator="")[0].content
+    # Convert string response to proper dictionary format
+    decision_dict = {"decision": response_string.lower()}
+    formatted_response = BinaryDecision.model_validate(decision_dict)
+
+    print(f"Verdict: {formatted_response.decision}")
+
+    return {
+        "sub_question_verified_retrieval_docs": [state["document"]]
+        if formatted_response.decision == "yes"
+        else [],
+        "log_messages": generate_log_message(
+            message=f"sub - verifier: {formatted_response.decision}",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
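The verifier above shows the pattern used when no structured response format is requested: stream the raw reply, lower-case it, and validate it into the BinaryDecision model so downstream code only ever sees "yes" or "no". A sketch of just the coercion step; the real BinaryDecision lives in shared_graph_utils.models, so this stand-in assumes it is a pydantic model with a single literal field:

from typing import Literal

from pydantic import BaseModel


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


raw_reply = "Yes"  # whatever the model streamed back; real replies may need more cleanup
verdict = BinaryDecision.model_validate({"decision": raw_reply.strip().lower()})
assert verdict.decision == "yes"

diff --git a/backend/danswer/agent_search/base_qa_sub_graph/states.py b/backend/danswer/agent_search/core_qa_graph/states.py
similarity index 73%
rename from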
backend/danswer/agent_search/base_qa_sub_graph/states.py rename to backend/danswer/agent_search/core_qa_graph/states.py index f14533035d8..5c53df06275 100644 --- a/backend/danswer/agent_search/base_qa_sub_graph/states.py +++ b/backend/danswer/agent_search/core_qa_graph/states.py @@ -7,7 +7,8 @@ from langchain_core.messages import BaseMessage from langgraph.graph.message import add_messages -from danswer.chat.models import DanswerContext +from danswer.agent_search.shared_graph_utils.models import RewrittenQueries +from danswer.context.search.models import InferenceSection from danswer.llm.interfaces import LLM @@ -18,10 +19,15 @@ class SubQuestionRetrieverState(TypedDict): class SubQuestionVerifierState(TypedDict): # The state for the parallel verification step. Each node execution need to see only one question/doc pair - sub_question_document: DanswerContext + sub_question_document: InferenceSection sub_question: str +class CoreQAInputState(TypedDict): + sub_question_str: str + original_question: str + + class BaseQAState(TypedDict): # The 'core SubQuestion' state. original_question: str @@ -30,17 +36,19 @@ class BaseQAState(TypedDict): sub_query_start_time: datetime sub_question_rewritten_queries: list[str] sub_question_str: str - sub_question_search_queries: list[str] + sub_question_search_queries: RewrittenQueries sub_question_nr: int - sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + sub_question_base_retrieval_docs: Annotated[ + Sequence[InferenceSection], operator.add + ] sub_question_deduped_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_verified_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_reranked_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] sub_question_answer: str @@ -63,15 +71,17 @@ class BaseQAOutputState(TypedDict): sub_qas: Annotated[Sequence[dict], operator.add] # Answers sent back to core initial_sub_qas: Annotated[Sequence[dict], operator.add] - sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + sub_question_base_retrieval_docs: Annotated[ + Sequence[InferenceSection], operator.add + ] sub_question_deduped_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_verified_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_reranked_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] sub_question_answer: str diff --git a/backend/danswer/agent_search/research_qa_sub_graph/edges.py b/backend/danswer/agent_search/deep_qa_graph/edges.py similarity index 95% rename from backend/danswer/agent_search/research_qa_sub_graph/edges.py rename to backend/danswer/agent_search/deep_qa_graph/edges.py index 48afc58206e..980af9159b9 100644 --- a/backend/danswer/agent_search/research_qa_sub_graph/edges.py +++ b/backend/danswer/agent_search/deep_qa_graph/edges.py @@ -3,9 +3,9 @@ from langgraph.types import Send +from danswer.agent_search.deep_qa_graph.states import ResearchQAState from danswer.agent_search.primary_graph.states import RetrieverState from 
danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]: diff --git a/backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py b/backend/danswer/agent_search/deep_qa_graph/graph_builder.py similarity index 61% rename from backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py rename to backend/danswer/agent_search/deep_qa_graph/graph_builder.py index 40ea5580ad8..90c7aebeb0f 100644 --- a/backend/danswer/agent_search/research_qa_sub_graph/graph_builder.py +++ b/backend/danswer/agent_search/deep_qa_graph/graph_builder.py @@ -2,22 +2,22 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from danswer.agent_search.research_qa_sub_graph.edges import sub_continue_to_retrieval -from danswer.agent_search.research_qa_sub_graph.edges import sub_continue_to_verifier -from danswer.agent_search.research_qa_sub_graph.nodes import sub_combine_retrieved_docs -from danswer.agent_search.research_qa_sub_graph.nodes import sub_custom_retrieve -from danswer.agent_search.research_qa_sub_graph.nodes import sub_dummy -from danswer.agent_search.research_qa_sub_graph.nodes import sub_final_format -from danswer.agent_search.research_qa_sub_graph.nodes import sub_generate -from danswer.agent_search.research_qa_sub_graph.nodes import sub_qa_check -from danswer.agent_search.research_qa_sub_graph.nodes import sub_verifier -from danswer.agent_search.research_qa_sub_graph.states import ResearchQAOutputState -from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState - -# from danswer.agent_search.research_qa_sub_graph.nodes import sub_rewrite - - -def build_base_qa_sub_graph() -> StateGraph: +from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval +from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier +from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import ( + sub_combine_retrieved_docs, +) +from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve +from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy +from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format +from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate +from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check +from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier +from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState +from danswer.agent_search.deep_qa_graph.states import ResearchQAState + + +def build_deep_qa_graph() -> StateGraph: # Define the nodes we will cycle between sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState) @@ -48,26 +48,32 @@ def build_base_qa_sub_graph() -> StateGraph: ### Add Edges ### + # Generate multiple sub-questions sub_answers.add_edge(start_key=START, end_key="sub_rewrite") + # For each sub-question, perform a retrieval in parallel sub_answers.add_conditional_edges( source="sub_rewrite", path=sub_continue_to_retrieval, path_map=["sub_custom_retrieve"], ) + # Combine the retrieved docs for each sub-question from the parallel retrievals sub_answers.add_edge( start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs" ) + # Go over all of the combined retrieved docs and verify them against the original question 
sub_answers.add_conditional_edges( source="sub_combine_retrieved_docs", path=sub_continue_to_verifier, path_map=["sub_verifier"], ) + # Generate an answer for each verified retrieved doc sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate") + # Check the quality of the answer sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check") sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format") @@ -80,7 +86,7 @@ def build_base_qa_sub_graph() -> StateGraph: if __name__ == "__main__": # TODO: add the actual question inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"} - sub_answers_graph = build_base_qa_sub_graph() + sub_answers_graph = build_deep_qa_graph() compiled_sub_answers = sub_answers_graph.compile() output = compiled_sub_answers.invoke(inputs) print("\nOUTPUT:") diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/__init__.py b/backend/danswer/agent_search/deep_qa_graph/nodes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py new file mode 100644 index 00000000000..542c823ae33 --- /dev/null +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py @@ -0,0 +1,31 @@ +from datetime import datetime +from typing import Any + +from danswer.agent_search.deep_qa_graph.states import ResearchQAState +from danswer.agent_search.shared_graph_utils.utils import generate_log_message + + +def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]: + """ + Dedupe the retrieved docs. + """ + node_start_time = datetime.now() + + sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"] + + print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}") + dedupe_docs = [] + for base_retrieval_doc in sub_question_base_retrieval_docs: + if base_retrieval_doc not in dedupe_docs: + dedupe_docs.append(base_retrieval_doc) + + print(f"Number of deduped docs: {len(dedupe_docs)}") + + return { + "sub_question_deduped_retrieval_docs": dedupe_docs, + "log_messages": generate_log_message( + message="sub - combine_retrieved_docs (dedupe)", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py new file mode 100644 index 00000000000..b041de35768 --- /dev/null +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import Any + +from danswer.agent_search.primary_graph.states import RetrieverState +from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.context.search.models import InferenceSection + + +def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: + """ + Retrieve documents + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print("---RETRIEVE SUB---") + node_start_time = datetime.now() + + # Retrieval + # TODO: add the actual retrieval, probably from search_tool.run() + documents: list[InferenceSection] = [] + + return { + "sub_question_base_retrieval_docs": documents, + "log_messages": generate_log_message( + message="sub - custom_retrieve", + node_start_time=node_start_time, + 
graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py b/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py
new file mode 100644
index 00000000000..976c81744de
--- /dev/null
+++ b/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py
@@ -0,0 +1,21 @@
+from datetime import datetime
+from typing import Any
+
+from danswer.agent_search.core_qa_graph.states import BaseQAState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_dummy(state: BaseQAState) -> dict[str, Any]:
+    """
+    Dummy step
+    """
+
+    print("---Sub Dummy---")
+
+    return {
+        "log_messages": generate_log_message(
+            message="sub - dummy",
+            node_start_time=datetime.now(),
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py b/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py
new file mode 100644
index 00000000000..d3a8706f9e1
--- /dev/null
+++ b/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py
@@ -0,0 +1,31 @@
+from datetime import datetime
+from typing import Any
+
+from danswer.agent_search.deep_qa_graph.states import ResearchQAState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Create the final output for the QA subgraph
+    """
+
+    print("---SUB FINAL FORMAT---")
+    node_start_time = datetime.now()
+
+    return {
+        # TODO: Type this
+        "sub_qas": [
+            {
+                "sub_question": state["sub_question"],
+                "sub_answer": state["sub_question_answer"],
+                "sub_question_nr": state["sub_question_nr"],
+                "sub_answer_check": state["sub_question_answer_check"],
+            }
+        ],
+        "log_messages": generate_log_message(
+            message="sub - final format",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py b/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py
new file mode 100644
index 00000000000..dbd478fb796
--- /dev/null
+++ b/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py
@@ -0,0 +1,56 @@
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import merge_message_runs
+
+from danswer.agent_search.deep_qa_graph.states import ResearchQAState
+from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import format_docs
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_generate(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the generated answer
+    """
+    print("---SUB GENERATE---")
+    node_start_time = datetime.now()
+
+    question = state["sub_question"]
+    docs = state["sub_question_verified_retrieval_docs"]
+
+    print(f"Number of verified retrieval docs for sub-question: {len(docs)}")
+
+    msg = [
+        HumanMessage(
+            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
+        )
+    ]
+
+    # Grader
+    if len(docs) > 0:
+        model = state["fast_llm"]
+        response = list(
+            model.stream(
+                prompt=msg,
+            )
+        )
+        response_str = merge_message_runs(response, chunk_separator="")[0].content
+    else:
+        response_str = ""
+
+    return {
+        "sub_question_answer": response_str,
+        "log_messages":
generate_log_message(
+            message="sub - generate",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py b/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py
new file mode 100644
index 00000000000..5dd77af213c
--- /dev/null
+++ b/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py
@@ -0,0 +1,57 @@
+import json
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
+from danswer.agent_search.deep_qa_graph.states import ResearchQAState
+from danswer.agent_search.shared_graph_utils.models import BinaryDecision
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Check whether the final output satisfies the original user question
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---CHECK SUB OUTPUT---")
+    node_start_time = datetime.now()
+
+    sub_answer = state["sub_question_answer"]
+    sub_question = state["sub_question"]
+
+    msg = [
+        HumanMessage(
+            content=SUB_CHECK_PROMPT.format(
+                sub_question=sub_question, sub_answer=sub_answer
+            )
+        )
+    ]
+
+    # Grader
+    model = state["fast_llm"]
+    response = list(
+        model.stream(
+            prompt=msg,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    raw_response = json.loads(response[0].pretty_repr())
+    formatted_response = BinaryDecision.model_validate(raw_response)
+
+    return {
+        "sub_question_answer_check": formatted_response.decision,
+        "log_messages": generate_log_message(
+            message=f"sub - qa check: {formatted_response.decision}",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py b/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py
new file mode 100644
index 00000000000..f7cb836842a
--- /dev/null
+++ b/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py
@@ -0,0 +1,64 @@
+import json
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.deep_qa_graph.states import ResearchQAState
+from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
+from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+from danswer.llm.interfaces import LLM
+
+
+def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
+    """
+    Transform the initial question into more suitable search queries.
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with re-phrased question
+    """
+
+    print("---SUB TRANSFORM QUERY---")
+    node_start_time = datetime.now()
+
+    question = state["sub_question"]
+
+    msg = [
+        HumanMessage(
+            content=REWRITE_PROMPT_MULTI.format(question=question),
+        )
+    ]
+    fast_llm: LLM = state["fast_llm"]
+    llm_response = list(
+        fast_llm.stream(
+            prompt=msg,
+            structured_response_format=RewrittenQueries.model_json_schema(),
+        )
+    )
+
+    # Get the rewritten queries in a defined format
+    rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
+
+    print(f"rewritten_queries: {rewritten_queries}")
+
+    rewritten_queries = RewrittenQueries(
+        rewritten_queries=[
+            "music hard to listen to",
+            "Music that is not fun or pleasant",
+        ]
+    )
+
+    print(f"hardcoded rewritten_queries: {rewritten_queries}")
+
+    return {
+        "sub_question_rewritten_queries": rewritten_queries,
+        "log_messages": generate_log_message(
+            message="sub - rewrite",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py b/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py
new file mode 100644
index 00000000000..49bfcb62d17
--- /dev/null
+++ b/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py
@@ -0,0 +1,59 @@
+import json
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.primary_graph.states import VerifierState
+from danswer.agent_search.shared_graph_utils.models import BinaryDecision
+from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_verifier(state: VerifierState) -> dict[str, Any]:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        state (VerifierState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---SUB VERIFY OUTPUT---")
+    node_start_time = datetime.now()
+
+    question = state["question"]
+    document_content = state["document"].combined_content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=question, document_content=document_content
+            )
+        )
+    ]
+
+    # Grader
+    model = state["fast_llm"]
+    response = list(
+        model.stream(
+            prompt=msg,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    raw_response = json.loads(response[0].pretty_repr())
+    formatted_response = BinaryDecision.model_validate(raw_response)
+
+    return {
+        "sub_question_verified_retrieval_docs": [state["document"]]
+        if formatted_response.decision == "yes"
+        else [],
+        "log_messages": generate_log_message(
+            message=f"sub - verifier: {formatted_response.decision}",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
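The states.py rename that follows keeps the same shape but swaps DanswerContext for InferenceSection inside LangGraph's reducer annotations. The Annotated[Sequence[...], operator.add] wrapper is what lets the parallel Send() branches write to one field safely: each branch returns a partial list and LangGraph concatenates the updates instead of overwriting them. A minimal sketch of the annotation on its own (DemoState is invented for illustration):

import operator
from typing import Annotated, Sequence, TypedDict


class DemoState(TypedDict):
    # Two parallel nodes each returning {"docs": [...]} end up with their
    # lists concatenated, because operator.add is the field's reducer.
    docs: Annotated[Sequence[str], operator.add]

diff --git a/backend/danswer/agent_search/research_qa_sub_graph/prompts.py b/backend/danswer/agent_search/deep_qa_graph/prompts.py
similarity index 100%
rename from backend/danswer/agent_search/research_qa_sub_graph/prompts.py
rename to backend/danswer/agent_search/deep_qa_graph/prompts.py
diff --git a/backend/danswer/agent_search/research_qa_sub_graph/states.py b/backend/danswer/agent_search/deep_qa_graph/states.py
similarity index 74%
rename from backend/danswer/agent_search/research_qa_sub_graph/states.py
rename to backend/danswer/agent_search/deep_qa_graph/states.py
index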
abc0d3b04f1..2492f4b4ee5 100644 --- a/backend/danswer/agent_search/research_qa_sub_graph/states.py +++ b/backend/danswer/agent_search/deep_qa_graph/states.py @@ -7,7 +7,7 @@ from langchain_core.messages import BaseMessage from langgraph.graph.message import add_messages -from danswer.chat.models import DanswerContext +from danswer.context.search.models import InferenceSection from danswer.llm.interfaces import LLM @@ -18,15 +18,17 @@ class ResearchQAState(TypedDict): sub_question_rewritten_queries: list[str] sub_question: str sub_question_nr: int - sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + sub_question_base_retrieval_docs: Annotated[ + Sequence[InferenceSection], operator.add + ] sub_question_deduped_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_verified_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_reranked_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] sub_question_answer: str @@ -44,15 +46,17 @@ class ResearchQAOutputState(TypedDict): sub_question_nr: int # Answers sent back to core sub_qas: Annotated[Sequence[dict], operator.add] - sub_question_base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + sub_question_base_retrieval_docs: Annotated[ + Sequence[InferenceSection], operator.add + ] sub_question_deduped_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_verified_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_reranked_retrieval_docs: Annotated[ - Sequence[DanswerContext], operator.add + Sequence[InferenceSection], operator.add ] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] sub_question_answer: str diff --git a/backend/danswer/agent_search/primary_graph/edges.py b/backend/danswer/agent_search/primary_graph/edges.py index a21709ec3b3..495ae807f79 100644 --- a/backend/danswer/agent_search/primary_graph/edges.py +++ b/backend/danswer/agent_search/primary_graph/edges.py @@ -4,9 +4,9 @@ from langchain_core.messages import HumanMessage from langgraph.types import Send -from danswer.agent_search.base_qa_sub_graph.states import BaseQAState +from danswer.agent_search.core_qa_graph.states import BaseQAState +from danswer.agent_search.deep_qa_graph.states import ResearchQAState from danswer.agent_search.primary_graph.states import QAState -from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT diff --git a/backend/danswer/agent_search/primary_graph/graph_builder.py b/backend/danswer/agent_search/primary_graph/graph_builder.py index 20ce65c5588..6bce818514c 100644 --- a/backend/danswer/agent_search/primary_graph/graph_builder.py +++ b/backend/danswer/agent_search/primary_graph/graph_builder.py @@ -2,18 +2,33 @@ from langgraph.graph import START from langgraph.graph import StateGraph +from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph +from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions from danswer.agent_search.primary_graph.edges 
import continue_to_deep_answer
 from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
-from danswer.agent_search.primary_graph.nodes import base_wait
-from danswer.agent_search.primary_graph.nodes import combine_retrieved_docs
-from danswer.agent_search.primary_graph.nodes import custom_retrieve
-from danswer.agent_search.primary_graph.nodes import entity_term_extraction
-from danswer.agent_search.primary_graph.nodes import final_stuff
-from danswer.agent_search.primary_graph.nodes import generate_initial
-from danswer.agent_search.primary_graph.nodes import main_decomp_base
-from danswer.agent_search.primary_graph.nodes import rewrite
-from danswer.agent_search.primary_graph.nodes import verifier
+from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
+from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
+    combine_retrieved_docs,
+)
+from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
+from danswer.agent_search.primary_graph.nodes.decompose import decompose
+from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
+    deep_answer_generation,
+)
+from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
+from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
+    entity_term_extraction,
+)
+from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
+from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
+from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
+from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
+from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
+    sub_qa_level_aggregator,
+)
+from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
+from danswer.agent_search.primary_graph.nodes.verifier import verifier
 from danswer.agent_search.primary_graph.states import QAState
@@ -22,6 +37,7 @@ def build_core_graph() -> StateGraph:
     core_answer_graph = StateGraph(state_schema=QAState)
 
     ### Add Nodes ###
+    core_answer_graph.add_node(node="dummy_start", action=dummy_start)
 
     # Re-writing the question
     core_answer_graph.add_node(node="rewrite", action=rewrite)
@@ -45,9 +61,37 @@ def build_core_graph() -> StateGraph:
     # Initial question decomposition
     core_answer_graph.add_node(node="main_decomp_base", action=main_decomp_base)
 
+    # Build the base QA sub-graph and compile it
+    compiled_core_qa_graph = build_core_qa_graph().compile()
+    # Add the compiled base QA sub-graph as a node to the core graph
+    core_answer_graph.add_node(
+        node="sub_answers_graph_initial", action=compiled_core_qa_graph
+    )
+
     # Wait for all initial sub-question answers before proceeding
     core_answer_graph.add_node(node="base_wait", action=base_wait)
 
+    # Decompose the question into sub-questions
+    core_answer_graph.add_node(node="decompose", action=decompose)
+
+    # Manage the sub-questions
+    core_answer_graph.add_node(node="sub_qa_manager", action=sub_qa_manager)
+
+    # Build the research QA sub-graph and compile it
+    compiled_deep_qa_graph = build_deep_qa_graph().compile()
+    # Add the compiled research QA sub-graph as a node to the core graph
+    core_answer_graph.add_node(node="sub_answers_graph", action=compiled_deep_qa_graph)
+
+    # Aggregate the sub-questions and their answers
+    core_answer_graph.add_node(
+        node="sub_qa_level_aggregator", action=sub_qa_level_aggregator
+    )
+
+    # Generate the deep answer from the aggregated sub-questions and answers
+
core_answer_graph.add_node( + node="deep_answer_generation", action=deep_answer_generation + ) + # A final clean-up step core_answer_graph.add_node(node="final_stuff", action=final_stuff) @@ -61,7 +105,6 @@ def build_core_graph() -> StateGraph: core_answer_graph.add_conditional_edges( source="main_decomp_base", path=continue_to_initial_sub_questions, - path_map={"sub_answers_graph_initial": "sub_answers_graph_initial"}, ) # use the retrieved information to generate the answer @@ -82,7 +125,6 @@ def build_core_graph() -> StateGraph: core_answer_graph.add_conditional_edges( source="sub_qa_manager", path=continue_to_answer_sub_questions, - path_map={"sub_answers_graph": "sub_answers_graph"}, ) core_answer_graph.add_edge( diff --git a/backend/danswer/agent_search/primary_graph/nodes.py b/backend/danswer/agent_search/primary_graph/nodes.py deleted file mode 100644 index 5252e60069b..00000000000 --- a/backend/danswer/agent_search/primary_graph/nodes.py +++ /dev/null @@ -1,492 +0,0 @@ -import json -import re -from collections.abc import Sequence -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage - -from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT -from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT -from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT -from danswer.agent_search.primary_graph.states import QAState -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.shared_graph_utils.models import BinaryDecision -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string -from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.chat.models import DanswerContext -from danswer.llm.interfaces import LLM - -# Maybe try Partial[QAState] -# from typing import Partial - - -# Transform the initial question into more suitable search queries. -def rewrite(state: QAState) -> dict[str, Any]: - """ - Transform the initial question into more suitable search queries. 
- - Args: - qa_state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---STARTING GRAPH---") - graph_start_time = datetime.now() - - print("---TRANSFORM QUERY---") - node_start_time = datetime.now() - - question = state["original_question"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), - ) - ] - - # Get the rewritten queries in a defined format - fast_llm: LLM = state["fast_llm"] - llm_response = list( - fast_llm.stream( - prompt=msg, - structured_response_format=RewrittenQueries.model_json_schema(), - ) - ) - - formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr()) - - return { - "rewritten_queries": formatted_response.rewritten_queries, - "log_messages": generate_log_message( - message="core - rewrite", - node_start_time=node_start_time, - graph_start_time=graph_start_time, - ), - } - - -def custom_retrieve(state: RetrieverState) -> dict[str, Any]: - """ - Retrieve documents - - Args: - retriever_state (dict): The current graph state - - Returns: - state (dict): New key added to state, documents, that contains retrieved documents - """ - print("---RETRIEVE---") - - node_start_time = datetime.now() - - # query = state["rewritten_query"] - - # Retrieval - # TODO: add the actual retrieval, probably from search_tool.run() - documents: list[DanswerContext] = [] - - return { - "base_retrieval_docs": documents, - "log_messages": generate_log_message( - message="core - custom_retrieve", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def combine_retrieved_docs(state: QAState) -> dict[str, Any]: - """ - Dedupe the retrieved docs. - """ - node_start_time = datetime.now() - - base_retrieval_docs: Sequence[DanswerContext] = state["base_retrieval_docs"] - - print(f"Number of docs from steps: {len(base_retrieval_docs)}") - dedupe_docs: list[DanswerContext] = [] - for base_retrieval_doc in base_retrieval_docs: - if not any( - base_retrieval_doc.document_id == doc.document_id for doc in dedupe_docs - ): - dedupe_docs.append(base_retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return { - "deduped_retrieval_docs": dedupe_docs, - "log_messages": generate_log_message( - message="core - combine_retrieved_docs (dedupe)", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def verifier(state: VerifierState) -> dict[str, Any]: - """ - Check whether the document is relevant for the original user question - - Args: - state (VerifierState): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---VERIFY QUTPUT---") - node_start_time = datetime.now() - - question = state["question"] - document_content = state["document"].content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - # Grader - llm: LLM = state["fast_llm"] - response = list( - llm.stream( - prompt=msg, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - raw_response = json.loads(response[0].pretty_repr()) - formatted_response = BinaryDecision.model_validate(raw_response) - - return { - "deduped_retrieval_docs": [state["document"]] - if formatted_response.decision == "yes" - else [], - "log_messages": generate_log_message( - message=f"core - verifier: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], 
- ), - } - - -def generate(state: QAState) -> dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---GENERATE---") - node_start_time = datetime.now() - - question = state["original_question"] - docs = state["deduped_retrieval_docs"] - - print(f"Number of verified retrieval docs: {len(docs)}") - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) - ) - ] - - # Grader - llm: LLM = state["fast_llm"] - response = list( - llm.stream( - prompt=msg, - structured_response_format=None, - ) - ) - - # Run - # response = rag_chain.invoke({"context": docs, - # "question": question}) - - return { - "base_answer": response[0].pretty_repr(), - "log_messages": generate_log_message( - message="core - generate", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def final_stuff(state: QAState) -> dict[str, Any]: - """ - Invokes the agent model to generate a response based on the current state. Given - the question, it will decide to retrieve using the retriever tool, or simply end. - - Args: - state (messages): The current state - - Returns: - dict: The updated state with the agent response appended to messages - """ - print("---FINAL---") - node_start_time = datetime.now() - - messages = state["log_messages"] - time_ordered_messages = [x.pretty_repr() for x in messages] - time_ordered_messages.sort() - - print("Message Log:") - print("\n".join(time_ordered_messages)) - - initial_sub_qas = state["initial_sub_qas"] - initial_sub_qa_list = [] - for initial_sub_qa in initial_sub_qas: - if initial_sub_qa["sub_answer_check"] == "yes": - initial_sub_qa_list.append( - f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----' - ) - - initial_sub_qa_context = "\n".join(initial_sub_qa_list) - - log_message = generate_log_message( - message="all - final_stuff", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ) - - print(log_message) - print("--------------------------------") - - base_answer = state["base_answer"] - - print(f"Final Base Answer:\n{base_answer}") - print("--------------------------------") - print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}") - print("--------------------------------") - - if not state.get("deep_answer"): - print("No Deep Answer was required") - return { - "log_messages": log_message, - } - - deep_answer = state["deep_answer"] - sub_qas = state["sub_qas"] - sub_qa_list = [] - for sub_qa in sub_qas: - if sub_qa["sub_answer_check"] == "yes": - sub_qa_list.append( - f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----' - ) - - sub_qa_context = "\n".join(sub_qa_list) - - print(f"Final Base Answer:\n{base_answer}") - print("--------------------------------") - print(f"Final Deep Answer:\n{deep_answer}") - print("--------------------------------") - print("Sub Questions and Answers:") - print(sub_qa_context) - - return { - "log_messages": generate_log_message( - message="all - final_stuff", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def base_wait(state: QAState) -> dict[str, Any]: - """ - Ensures that all required steps are completed before proceeding to the next step - - Args: - state (messages): The current state - - Returns: - dict: {} (no operation, just logging) - """ - - print("---Base Wait ---") - 
node_start_time = datetime.now() - return { - "log_messages": generate_log_message( - message="core - base_wait", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def entity_term_extraction(state: QAState) -> dict[str, Any]: - """ """ - - node_start_time = datetime.now() - - question = state["original_question"] - docs = state["deduped_retrieval_docs"] - - doc_context = format_docs(docs) - - msg = [ - HumanMessage( - content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), - ) - ] - - # Grader - model = state["fast_llm"] - response = model.invoke(msg) - - cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr()) - parsed_response = json.loads(cleaned_response) - - return { - "retrieved_entities_relationships": parsed_response, - "log_messages": generate_log_message( - message="deep - entity term extraction", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def generate_initial(state: QAState) -> dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---GENERATE INITIAL---") - node_start_time = datetime.now() - - question = state["original_question"] - docs = state["deduped_retrieval_docs"] - print(f"Number of verified retrieval docs - base: {len(docs)}") - - sub_question_answers = state["initial_sub_qas"] - - sub_question_answers_list = [] - - _SUB_QUESTION_ANSWER_TEMPLATE = """ - Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n - """ - for sub_question_answer_dict in sub_question_answers: - if ( - sub_question_answer_dict["sub_answer_check"] == "yes" - and len(sub_question_answer_dict["sub_answer"]) > 0 - and sub_question_answer_dict["sub_answer"] != "I don't know" - ): - sub_question_answers_list.append( - _SUB_QUESTION_ANSWER_TEMPLATE.format( - sub_question=sub_question_answer_dict["sub_question"], - sub_answer=sub_question_answer_dict["sub_answer"], - ) - ) - - sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list) - - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT.format( - question=question, - context=format_docs(docs), - answered_sub_questions=sub_question_answer_str, - ) - ) - ] - - # Grader - model = state["fast_llm"] - response = model.invoke(msg) - - # Run - # response = rag_chain.invoke({"context": docs, - # "question": question}) - - return { - "base_answer": response.pretty_repr(), - "log_messages": generate_log_message( - message="core - generate initial", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def main_decomp_base(state: QAState) -> dict[str, Any]: - """ - Perform an initial question decomposition, incl. 
one search term - - Args: - state (messages): The current state - - Returns: - dict: The updated state with initial decomposition - """ - - print("---INITIAL DECOMP---") - node_start_time = datetime.now() - - question = state["original_question"] - - msg = [ - HumanMessage( - content=INITIAL_DECOMPOSITION_PROMPT.format(question=question), - ) - ] - """ - msg = [ - HumanMessage( - content=INITIAL_DECOMPOSITION_PROMPT_BASIC.format(question=question), - ) - ] - """ - - # Get the rewritten queries in a defined format - model = state["fast_llm"] - response = model.invoke(msg) - - content = response.pretty_repr() - list_of_subquestions = clean_and_parse_list_string(content) - # response = model.invoke(msg) - - decomp_list = [] - - for sub_question_nr, sub_question in enumerate(list_of_subquestions): - sub_question_str = sub_question["sub_question"].strip() - # temporarily - sub_question_search_queries = [sub_question["search_term"]] - - decomp_list.append( - { - "sub_question_str": sub_question_str, - "sub_question_search_queries": sub_question_search_queries, - "sub_question_nr": sub_question_nr, - } - ) - - return { - "initial_sub_questions": decomp_list, - "sub_query_start_time": node_start_time, - "log_messages": generate_log_message( - message="core - initial decomp", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/primary_graph/nodes/__init__.py b/backend/danswer/agent_search/primary_graph/nodes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/agent_search/primary_graph/nodes/base_wait.py b/backend/danswer/agent_search/primary_graph/nodes/base_wait.py new file mode 100644 index 00000000000..e46aa530f2c --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/base_wait.py @@ -0,0 +1,27 @@ +from datetime import datetime +from typing import Any + +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.utils import generate_log_message + + +def base_wait(state: QAState) -> dict[str, Any]: + """ + Ensures that all required steps are completed before proceeding to the next step + + Args: + state (messages): The current state + + Returns: + dict: {} (no operation, just logging) + """ + + print("---Base Wait ---") + node_start_time = datetime.now() + return { + "log_messages": generate_log_message( + message="core - base_wait", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py new file mode 100644 index 00000000000..a175b74e5a6 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py @@ -0,0 +1,36 @@ +from collections.abc import Sequence +from datetime import datetime +from typing import Any + +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.context.search.models import InferenceSection + + +def combine_retrieved_docs(state: QAState) -> dict[str, Any]: + """ + Dedupe the retrieved docs. 
+ """ + node_start_time = datetime.now() + + base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"] + + print(f"Number of docs from steps: {len(base_retrieval_docs)}") + dedupe_docs: list[InferenceSection] = [] + for base_retrieval_doc in base_retrieval_docs: + if not any( + base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id + for doc in dedupe_docs + ): + dedupe_docs.append(base_retrieval_doc) + + print(f"Number of deduped docs: {len(dedupe_docs)}") + + return { + "deduped_retrieval_docs": dedupe_docs, + "log_messages": generate_log_message( + message="core - combine_retrieved_docs (dedupe)", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py new file mode 100644 index 00000000000..deaafbdf411 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py @@ -0,0 +1,52 @@ +from datetime import datetime +from typing import Any + +from danswer.agent_search.primary_graph.states import RetrieverState +from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.context.search.models import InferenceSection +from danswer.context.search.models import SearchRequest +from danswer.context.search.pipeline import SearchPipeline +from danswer.db.engine import get_session_context_manager +from danswer.llm.factory import get_default_llms + + +def custom_retrieve(state: RetrieverState) -> dict[str, Any]: + """ + Retrieve documents + + Args: + retriever_state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print("---RETRIEVE---") + + node_start_time = datetime.now() + + query = state["rewritten_query"] + + # Retrieval + # TODO: add the actual retrieval, probably from search_tool.run() + llm, fast_llm = get_default_llms() + with get_session_context_manager() as db_session: + top_sections = SearchPipeline( + search_request=SearchRequest( + query=query, + ), + user=None, + llm=llm, + fast_llm=fast_llm, + db_session=db_session, + ).reranked_sections + print(len(top_sections)) + documents: list[InferenceSection] = [] + + return { + "base_retrieval_docs": documents, + "log_messages": generate_log_message( + message="core - custom_retrieve", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/primary_graph/nodes/decompose.py b/backend/danswer/agent_search/primary_graph/nodes/decompose.py new file mode 100644 index 00000000000..351d374b464 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/decompose.py @@ -0,0 +1,78 @@ +import json +import re +from datetime import datetime +from typing import Any + +from langchain_core.messages import HumanMessage + +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction +from danswer.agent_search.shared_graph_utils.utils import generate_log_message + + +def decompose(state: QAState) -> dict[str, Any]: + """ """ + + node_start_time = datetime.now() + + question = state["original_question"] + base_answer = state["base_answer"] + + # get the entity term extraction dict and properly format it + entity_term_extraction_dict = 
state["retrieved_entities_relationships"][ + "retrieved_entities_relationships" + ] + + entity_term_extraction_str = format_entity_term_extraction( + entity_term_extraction_dict + ) + + initial_question_answers = state["initial_sub_qas"] + + addressed_question_list = [ + x["sub_question"] + for x in initial_question_answers + if x["sub_answer_check"] == "yes" + ] + failed_question_list = [ + x["sub_question"] + for x in initial_question_answers + if x["sub_answer_check"] == "no" + ] + + msg = [ + HumanMessage( + content=DEEP_DECOMPOSE_PROMPT.format( + question=question, + entity_term_extraction_str=entity_term_extraction_str, + base_answer=base_answer, + answered_sub_questions="\n - ".join(addressed_question_list), + failed_sub_questions="\n - ".join(failed_question_list), + ), + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr()) + parsed_response = json.loads(cleaned_response) + + sub_questions_dict = {} + for sub_question_nr, sub_question_dict in enumerate( + parsed_response["sub_questions"] + ): + sub_question_dict["answered"] = False + sub_question_dict["verified"] = False + sub_questions_dict[sub_question_nr] = sub_question_dict + + return { + "decomposed_sub_questions_dict": sub_questions_dict, + "log_messages": generate_log_message( + message="deep - decompose", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py b/backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py new file mode 100644 index 00000000000..55d41162e5b --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py @@ -0,0 +1,61 @@ +from datetime import datetime +from typing import Any + +from langchain_core.messages import HumanMessage + +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT +from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_docs +from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace + + +# aggregate sub questions and answers +def deep_answer_generation(state: QAState) -> dict[str, Any]: + """ + Generate answer + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + print("---DEEP GENERATE---") + + node_start_time = datetime.now() + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + deep_answer_context = state["core_answer_dynamic_context"] + + print(f"Number of verified retrieval docs - deep: {len(docs)}") + + combined_context = normalize_whitespace( + COMBINED_CONTEXT.format( + deep_answer_context=deep_answer_context, formated_docs=format_docs(docs) + ) + ) + + msg = [ + HumanMessage( + content=MODIFIED_RAG_PROMPT.format( + question=question, combined_context=combined_context + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + return { + "deep_answer": response.content, + "log_messages": generate_log_message( + message="deep - deep answer generation", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git 
a/backend/danswer/agent_search/primary_graph/nodes/dummy_start.py b/backend/danswer/agent_search/primary_graph/nodes/dummy_start.py new file mode 100644 index 00000000000..62e3dd92a7d --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/dummy_start.py @@ -0,0 +1,11 @@ +from datetime import datetime +from typing import Any + +from danswer.agent_search.primary_graph.states import QAState + + +def dummy_start(state: QAState) -> dict[str, Any]: + """ + Dummy node to set the start time + """ + return {"start_time": datetime.now()} diff --git a/backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py b/backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py new file mode 100644 index 00000000000..b19de6d4f39 --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py @@ -0,0 +1,51 @@ +import json +import re +from datetime import datetime +from typing import Any + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.utils import format_docs +from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.llm.factory import get_default_llms + + +def entity_term_extraction(state: QAState) -> dict[str, Any]: + """Extract entities and terms from the question and context""" + node_start_time = datetime.now() + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + doc_context = format_docs(docs) + + msg = [ + HumanMessage( + content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), + ) + ] + _, fast_llm = get_default_llms() + # Grader + llm_response_list = list( + fast_llm.stream( + prompt=msg, + # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()}, + # structured_response_format=RewrittenQueries.model_json_schema(), + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + cleaned_response = re.sub(r"```json\n|\n```", "", llm_response) + parsed_response = json.loads(cleaned_response) + + return { + "retrieved_entities_relationships": parsed_response, + "log_messages": generate_log_message( + message="deep - entity term extraction", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/danswer/agent_search/primary_graph/nodes/final_stuff.py b/backend/danswer/agent_search/primary_graph/nodes/final_stuff.py new file mode 100644 index 00000000000..be115de8dda --- /dev/null +++ b/backend/danswer/agent_search/primary_graph/nodes/final_stuff.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Any + +from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.shared_graph_utils.utils import generate_log_message + + +def final_stuff(state: QAState) -> dict[str, Any]: + """ + Invokes the agent model to generate a response based on the current state. Given + the question, it will decide to retrieve using the retriever tool, or simply end. 
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the agent response appended to messages
+    """
+    print("---FINAL---")
+    node_start_time = datetime.now()
+
+    messages = state["log_messages"]
+    time_ordered_messages = [x.pretty_repr() for x in messages]
+    time_ordered_messages.sort()
+
+    print("Message Log:")
+    print("\n".join(time_ordered_messages))
+
+    initial_sub_qas = state["initial_sub_qas"]
+    initial_sub_qa_list = []
+    for initial_sub_qa in initial_sub_qas:
+        if initial_sub_qa["sub_answer_check"] == "yes":
+            initial_sub_qa_list.append(
+                f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
+            )
+
+    initial_sub_qa_context = "\n".join(initial_sub_qa_list)
+
+    log_message = generate_log_message(
+        message="all - final_stuff",
+        node_start_time=node_start_time,
+        graph_start_time=state["graph_start_time"],
+    )
+
+    print(log_message)
+    print("--------------------------------")
+
+    base_answer = state["base_answer"]
+
+    print(f"Final Base Answer:\n{base_answer}")
+    print("--------------------------------")
+    print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
+    print("--------------------------------")
+
+    if not state.get("deep_answer"):
+        print("No Deep Answer was required")
+        return {
+            "log_messages": log_message,
+        }
+
+    deep_answer = state["deep_answer"]
+    sub_qas = state["sub_qas"]
+    sub_qa_list = []
+    for sub_qa in sub_qas:
+        if sub_qa["sub_answer_check"] == "yes":
+            sub_qa_list.append(
+                f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
+            )
+
+    sub_qa_context = "\n".join(sub_qa_list)
+
+    print(f"Final Base Answer:\n{base_answer}")
+    print("--------------------------------")
+    print(f"Final Deep Answer:\n{deep_answer}")
+    print("--------------------------------")
+    print("Sub Questions and Answers:")
+    print(sub_qa_context)
+
+    return {
+        "log_messages": generate_log_message(
+            message="all - final_stuff",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/generate.py b/backend/danswer/agent_search/primary_graph/nodes/generate.py
new file mode 100644
index 00000000000..9ff707177d8
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/generate.py
@@ -0,0 +1,52 @@
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.primary_graph.states import QAState
+from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import format_docs
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def generate(state: QAState) -> dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the base answer
+    """
+    print("---GENERATE---")
+    node_start_time = datetime.now()
+
+    question = state["original_question"]
+    docs = state["deduped_retrieval_docs"]
+
+    print(f"Number of verified retrieval docs: {len(docs)}")
+
+    msg = [
+        HumanMessage(
+            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
+        )
+    ]
+
+    # Generator
+    llm = state["fast_llm"]
+    response = list(
+        llm.stream(
+            prompt=msg,
+            structured_response_format=None,
+        )
+    )
+
+    return {
+        "base_answer": response[0].pretty_repr(),
+        "log_messages": generate_log_message(
+            message="core - generate",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/generate_initial.py b/backend/danswer/agent_search/primary_graph/nodes/generate_initial.py
new file mode 100644
index 00000000000..56ad83de96e
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/generate_initial.py
@@ -0,0 +1,72 @@
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
+from danswer.agent_search.primary_graph.states import QAState
+from danswer.agent_search.shared_graph_utils.utils import format_docs
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def generate_initial(state: QAState) -> dict[str, Any]:
+    """
+    Generate answer
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with the initial base answer
+    """
+    print("---GENERATE INITIAL---")
+    node_start_time = datetime.now()
+
+    question = state["original_question"]
+    docs = state["deduped_retrieval_docs"]
+    print(f"Number of verified retrieval docs - base: {len(docs)}")
+
+    sub_question_answers = state["initial_sub_qas"]
+
+    sub_question_answers_list = []
+
+    _SUB_QUESTION_ANSWER_TEMPLATE = """
+    Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
+    """
+    for sub_question_answer_dict in sub_question_answers:
+        if (
+            sub_question_answer_dict["sub_answer_check"] == "yes"
+            and len(sub_question_answer_dict["sub_answer"]) > 0
+            and sub_question_answer_dict["sub_answer"] != "I don't know"
+        ):
+            sub_question_answers_list.append(
+                _SUB_QUESTION_ANSWER_TEMPLATE.format(
+                    sub_question=sub_question_answer_dict["sub_question"],
+                    sub_answer=sub_question_answer_dict["sub_answer"],
+                )
+            )
+
+    sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)
+
+    msg = [
+        HumanMessage(
+            content=INITIAL_RAG_PROMPT.format(
+                question=question,
+                context=format_docs(docs),
+                answered_sub_questions=sub_question_answer_str,
+            )
+        )
+    ]
+
+    # Generator
+    model = state["fast_llm"]
+    response = model.invoke(msg)
+
+    return {
+        "base_answer": response.pretty_repr(),
+        "log_messages": generate_log_message(
+            message="core - generate initial",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py b/backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py
new file mode 100644
index 00000000000..f07100d61af
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py
@@ -0,0 +1,64 @@
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
+from danswer.agent_search.primary_graph.states import QAState
+from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def main_decomp_base(state: QAState) -> dict[str, Any]:
+    """
+    Perform an initial question decomposition, incl.
one search term
+
+    Args:
+        state (messages): The current state
+
+    Returns:
+        dict: The updated state with initial decomposition
+    """
+
+    print("---INITIAL DECOMP---")
+    node_start_time = datetime.now()
+
+    question = state["original_question"]
+
+    msg = [
+        HumanMessage(
+            content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
+        )
+    ]
+
+    # Get the initial decomposition in a defined format
+    model = state["fast_llm"]
+    response = model.invoke(msg)
+
+    content = response.pretty_repr()
+    list_of_subquestions = clean_and_parse_list_string(content)
+
+    decomp_list = []
+
+    for sub_question_nr, sub_question in enumerate(list_of_subquestions):
+        sub_question_str = sub_question["sub_question"].strip()
+        # temporarily use the single generated search term per sub-question
+        sub_question_search_queries = [sub_question["search_term"]]
+
+        decomp_list.append(
+            {
+                "sub_question_str": sub_question_str,
+                "sub_question_search_queries": sub_question_search_queries,
+                "sub_question_nr": sub_question_nr,
+            }
+        )
+
+    return {
+        "initial_sub_questions": decomp_list,
+        "sub_query_start_time": node_start_time,
+        "log_messages": generate_log_message(
+            message="core - initial decomp",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/rewrite.py b/backend/danswer/agent_search/primary_graph/nodes/rewrite.py
new file mode 100644
index 00000000000..07cbba5432c
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/rewrite.py
@@ -0,0 +1,55 @@
+import json
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.primary_graph.states import QAState
+from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
+from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def rewrite(state: QAState) -> dict[str, Any]:
+    """
+    Transform the initial question into more suitable search queries.
+
+    Args:
+        qa_state (messages): The current state
+
+    Returns:
+        dict: The updated state with the rewritten queries
+    """
+    print("---STARTING GRAPH---")
+    graph_start_time = datetime.now()
+
+    print("---TRANSFORM QUERY---")
+    node_start_time = datetime.now()
+
+    question = state["original_question"]
+
+    msg = [
+        HumanMessage(
+            content=REWRITE_PROMPT_MULTI.format(question=question),
+        )
+    ]
+
+    # Get the rewritten queries in a defined format
+    fast_llm = state["fast_llm"]
+    llm_response = list(
+        fast_llm.stream(
+            prompt=msg,
+            structured_response_format=RewrittenQueries.model_json_schema(),
+        )
+    )
+
+    formatted_response = RewrittenQueries.model_validate(
+        json.loads(llm_response[0].pretty_repr())
+    )
+
+    return {
+        "rewritten_queries": formatted_response.rewritten_queries,
+        "log_messages": generate_log_message(
+            message="core - rewrite",
+            node_start_time=node_start_time,
+            graph_start_time=graph_start_time,
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py b/backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py
new file mode 100644
index 00000000000..8d53ccc239b
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py
@@ -0,0 +1,39 @@
+from datetime import datetime
+from typing import Any
+
+from danswer.agent_search.primary_graph.states import QAState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+# aggregate sub questions and answers
+def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
+    sub_qas = state["sub_qas"]
+
+    node_start_time = datetime.now()
+
+    dynamic_context_list = [
+        "Below you will find useful information to answer the original question:"
+    ]
+    checked_sub_qas = []
+
+    for core_answer_sub_qa in sub_qas:
+        question = core_answer_sub_qa["sub_question"]
+        answer = core_answer_sub_qa["sub_answer"]
+        verified = core_answer_sub_qa["sub_answer_check"]
+
+        if verified == "yes":
+            dynamic_context_list.append(
+                f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
+            )
+            checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
+    dynamic_context = "\n".join(dynamic_context_list)
+
+    return {
+        "core_answer_dynamic_context": dynamic_context,
+        "checked_sub_qas": checked_sub_qas,
+        "log_messages": generate_log_message(
+            message="deep - sub qa level aggregator",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py b/backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py
new file mode 100644
index 00000000000..6e81dfb5dea
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py
@@ -0,0 +1,28 @@
+from datetime import datetime
+from typing import Any
+
+from danswer.agent_search.primary_graph.states import QAState
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def sub_qa_manager(state: QAState) -> dict[str, Any]:
+    """Fan the decomposed sub-questions out for individual answering."""
+
+    node_start_time = datetime.now()
+
+    sub_questions_dict = state["decomposed_sub_questions_dict"]
+
+    sub_questions = {}
+
+    for sub_question_nr, sub_question_dict in sub_questions_dict.items():
+        sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
+
+    return {
+        "sub_questions": sub_questions,
+        "num_new_question_iterations": 0,
+        "log_messages": generate_log_message(
+            message="deep - sub qa manager",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/verifier.py b/backend/danswer/agent_search/primary_graph/nodes/verifier.py
new file mode 100644
index 00000000000..1efdba03109
--- /dev/null
+++ b/backend/danswer/agent_search/primary_graph/nodes/verifier.py
@@ -0,0 +1,59 @@
+import json
+from datetime import datetime
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.primary_graph.states import VerifierState
+from danswer.agent_search.shared_graph_utils.models import BinaryDecision
+from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import generate_log_message
+
+
+def verifier(state: VerifierState) -> dict[str, Any]:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        state (VerifierState): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print("---VERIFY OUTPUT---")
+    node_start_time = datetime.now()
+
+    question = state["question"]
+    document_content = state["document"].combined_content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=question, document_content=document_content
+            )
+        )
+    ]
+
+    # Grader
+    llm = state["fast_llm"]
+    response = list(
+        llm.stream(
+            prompt=msg,
+            structured_response_format=BinaryDecision.model_json_schema(),
+        )
+    )
+
+    raw_response = json.loads(response[0].pretty_repr())
+    formatted_response = BinaryDecision.model_validate(raw_response)
+
+    return {
+        "deduped_retrieval_docs": [state["document"]]
+        if formatted_response.decision == "yes"
+        else [],
+        "log_messages": generate_log_message(
+            message=f"core - verifier: {formatted_response.decision}",
+            node_start_time=node_start_time,
+            graph_start_time=state["graph_start_time"],
+        ),
+    }
diff --git a/backend/danswer/agent_search/primary_graph/states.py b/backend/danswer/agent_search/primary_graph/states.py
index b692ca4cf1c..2e59fedfcd7 100644
--- a/backend/danswer/agent_search/primary_graph/states.py
+++ b/backend/danswer/agent_search/primary_graph/states.py
@@ -7,8 +7,8 @@
 from langchain_core.messages import BaseMessage
 from langgraph.graph.message import add_messages
 
-from danswer.chat.models import DanswerContext
-from danswer.llm.interfaces import LLM
+from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
+from danswer.context.search.models import InferenceSection
 
 
 class QAState(TypedDict):
@@ -18,7 +18,7 @@ class QAState(TypedDict):
     # start time for parallel initial sub-question thread
     sub_query_start_time: datetime
     log_messages: Annotated[Sequence[BaseMessage], add_messages]
-    rewritten_queries: list[str]
+    rewritten_queries: RewrittenQueries
     sub_questions: list[dict]
     initial_sub_questions: list[dict]
     ranked_subquestion_ids: list[int]
@@ -28,13 +28,13 @@ class QAState(TypedDict):
     sub_qas: Annotated[Sequence[dict], operator.add]
     initial_sub_qas: Annotated[Sequence[dict], operator.add]
     checked_sub_qas: Annotated[Sequence[dict], operator.add]
-    base_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add]
-    deduped_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add]
-    reranked_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add]
+    base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
+    deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
+    reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
     retrieved_entities_relationships: dict
questions_context: list[dict] qa_level: int - top_chunks: list[DanswerContext] + top_chunks: list[InferenceSection] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] num_new_question_iterations: int core_answer_dynamic_context: str @@ -42,8 +42,6 @@ class QAState(TypedDict): initial_base_answer: str base_answer: str deep_answer: str - primary_llm: LLM - fast_llm: LLM class QAOuputState(TypedDict): @@ -54,9 +52,9 @@ class QAOuputState(TypedDict): sub_qas: Annotated[Sequence[dict], operator.add] initial_sub_qas: Annotated[Sequence[dict], operator.add] checked_sub_qas: Annotated[Sequence[dict], operator.add] - reranked_retrieval_docs: Annotated[Sequence[DanswerContext], operator.add] + reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add] retrieved_entities_relationships: dict - top_chunks: list[DanswerContext] + top_chunks: list[InferenceSection] sub_question_top_chunks: Annotated[Sequence[dict], operator.add] base_answer: str deep_answer: str @@ -65,15 +63,11 @@ class QAOuputState(TypedDict): class RetrieverState(TypedDict): # The state for the parallel Retrievers. They each need to see only one query rewritten_query: str - primary_llm: LLM - fast_llm: LLM graph_start_time: datetime class VerifierState(TypedDict): # The state for the parallel verification step. Each node execution need to see only one question/doc pair - document: DanswerContext + document: InferenceSection question: str - primary_llm: LLM - fast_llm: LLM graph_start_time: datetime diff --git a/backend/danswer/agent_search/research_qa_sub_graph/nodes.py b/backend/danswer/agent_search/research_qa_sub_graph/nodes.py deleted file mode 100644 index 2c982cd7143..00000000000 --- a/backend/danswer/agent_search/research_qa_sub_graph/nodes.py +++ /dev/null @@ -1,308 +0,0 @@ -import json -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage - -from danswer.agent_search.base_qa_sub_graph.states import BaseQAState -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.research_qa_sub_graph.prompts import SUB_CHECK_PROMPT -from danswer.agent_search.research_qa_sub_graph.states import ResearchQAState -from danswer.agent_search.shared_graph_utils.models import BinaryDecision -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.chat.models import DanswerContext -from danswer.llm.interfaces import LLM - - -def sub_rewrite(state: ResearchQAState) -> dict[str, Any]: - """ - Transform the initial question into more suitable search queries. 
- - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - print("---SUB TRANSFORM QUERY---") - node_start_time = datetime.now() - - question = state["sub_question"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), - ) - ] - fast_llm: LLM = state["fast_llm"] - llm_response = list( - fast_llm.stream( - prompt=msg, - structured_response_format=RewrittenQueries.model_json_schema(), - ) - ) - - # Get the rewritten queries in a defined format - rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr()) - - print(f"rewritten_queries: {rewritten_queries}") - - rewritten_queries = RewrittenQueries( - rewritten_queries=[ - "music hard to listen to", - "Music that is not fun or pleasant", - ] - ) - - print(f"hardcoded rewritten_queries: {rewritten_queries}") - - return { - "sub_question_rewritten_queries": rewritten_queries, - "log_messages": generate_log_message( - message="sub - rewrite", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: - """ - Retrieve documents - - Args: - state (dict): The current graph state - - Returns: - state (dict): New key added to state, documents, that contains retrieved documents - """ - print("---RETRIEVE SUB---") - node_start_time = datetime.now() - - # Retrieval - # TODO: add the actual retrieval, probably from search_tool.run() - documents: list[DanswerContext] = [] - - return { - "sub_question_base_retrieval_docs": documents, - "log_messages": generate_log_message( - message="sub - custom_retrieve", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]: - """ - Dedupe the retrieved docs. 
- """ - node_start_time = datetime.now() - - sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"] - - print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}") - dedupe_docs = [] - for base_retrieval_doc in sub_question_base_retrieval_docs: - if base_retrieval_doc not in dedupe_docs: - dedupe_docs.append(base_retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return { - "sub_question_deduped_retrieval_docs": dedupe_docs, - "log_messages": generate_log_message( - message="sub - combine_retrieved_docs (dedupe)", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_verifier(state: VerifierState) -> dict[str, Any]: - """ - Check whether the document is relevant for the original user question - - Args: - state (VerifierState): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---SUB VERIFY QUTPUT---") - node_start_time = datetime.now() - - question = state["question"] - document_content = state["document"].content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - # Grader - model = state["fast_llm"] - response = model.invoke(msg) - - if response.pretty_repr().lower() == "yes": - return { - "sub_question_verified_retrieval_docs": [state["document"]], - "log_messages": generate_log_message( - message="sub - verifier: yes", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - else: - return { - "sub_question_verified_retrieval_docs": [], - "log_messages": generate_log_message( - message="sub - verifier: no", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_generate(state: ResearchQAState) -> dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---SUB GENERATE---") - node_start_time = datetime.now() - - question = state["sub_question"] - docs = state["sub_question_verified_retrieval_docs"] - - print(f"Number of verified retrieval docs for sub-question: {len(docs)}") - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) - ) - ] - - # Grader - if len(docs) > 0: - model = state["fast_llm"] - response = model.invoke(msg).pretty_repr() - else: - response = "" - - return { - "sub_question_answer": response, - "log_messages": generate_log_message( - message="sub - generate", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_final_format(state: ResearchQAState) -> dict[str, Any]: - """ - Create the final output for the QA subgraph - """ - - print("---SUB FINAL FORMAT---") - node_start_time = datetime.now() - - return { - # TODO: Type this - "sub_qas": [ - { - "sub_question": state["sub_question"], - "sub_answer": state["sub_question_answer"], - "sub_question_nr": state["sub_question_nr"], - "sub_answer_check": state["sub_question_answer_check"], - } - ], - "log_messages": generate_log_message( - message="sub - final format", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -# nodes - - -def sub_qa_check(state: ResearchQAState) -> dict[str, Any]: - """ - Check whether the final output satisfies the original user question - - Args: - state (messages): The current state - - Returns: - dict: The updated state 
with the final decision - """ - - print("---CHECK SUB QUTPUT---") - node_start_time = datetime.now() - - sub_answer = state["sub_question_answer"] - sub_question = state["sub_question"] - - msg = [ - HumanMessage( - content=SUB_CHECK_PROMPT.format( - sub_question=sub_question, sub_answer=sub_answer - ) - ) - ] - - # Grader - model = state["fast_llm"] - response = list( - model.stream( - prompt=msg, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - raw_response = json.loads(response[0].pretty_repr()) - formatted_response = BinaryDecision.model_validate(raw_response) - - return { - "sub_question_answer_check": formatted_response.decision, - "log_messages": generate_log_message( - message=f"sub - qa check: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } - - -def sub_dummy(state: BaseQAState) -> dict[str, Any]: - """ - Dummy step - """ - - print("---Sub Dummy---") - - return { - "log_messages": generate_log_message( - message="sub - dummy", - node_start_time=datetime.now(), - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/shared_graph_utils/prompts.py b/backend/danswer/agent_search/shared_graph_utils/prompts.py index e52234d40d2..3fb690c3cb4 100644 --- a/backend/danswer/agent_search/shared_graph_utils/prompts.py +++ b/backend/danswer/agent_search/shared_graph_utils/prompts.py @@ -93,7 +93,7 @@ Answer:""" -_ORIG_DEEP_DECOMPOSE_PROMPT = """ \n +ORIG_DEEP_DECOMPOSE_PROMPT = """ \n An initial user question needs to be answered. An initial answer has been provided but it wasn't quite good enough. Also, some sub-questions had been answered and this information has been used to provide the initial answer. Some other subquestions may have been suggested based on little knowledge, but they @@ -170,7 +170,7 @@ "search_term": }}, ...]}} """ -_DEEP_DECOMPOSE_PROMPT = """ \n +DEEP_DECOMPOSE_PROMPT = """ \n An initial user question needs to be answered. An initial answer has been provided but it wasn't quite good enough. Also, some sub-questions had been answered and this information has been used to provide the initial answer. Some other subquestions may have been suggested based on little knowledge, but they @@ -242,7 +242,7 @@ "search_term": }}, ...]}} """ -_DECOMPOSE_PROMPT = """ \n +DECOMPOSE_PROMPT = """ \n For an initial user question, please generate at 5-10 individual sub-questions whose answers would help \n to answer the initial question. The individual questions should be answerable by a good RAG system. So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the @@ -290,7 +290,7 @@ ...]}} """ #### Consolidations -_COMBINED_CONTEXT = """------- +COMBINED_CONTEXT = """------- Below you will find useful information to answer the original question. First, you see a number of sub-questions with their answers. This information should be considered to be more focussed and somewhat more specific to the original question as it tries to contextualized facts. @@ -303,7 +303,7 @@ ---------------- """ -_SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """------- +SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """------- Below you will find a question that we ultimately want to answer (the original question) and a list of motivations in arbitrary order for generated sub-questions that are supposed to help us answering the original question. The motivations are formatted as : . 
diff --git a/backend/danswer/agent_search/shared_graph_utils/utils.py b/backend/danswer/agent_search/shared_graph_utils/utils.py index 4bfd9fa87a8..95faf287ae0 100644 --- a/backend/danswer/agent_search/shared_graph_utils/utils.py +++ b/backend/danswer/agent_search/shared_graph_utils/utils.py @@ -6,7 +6,7 @@ from datetime import timedelta from typing import Any -from danswer.chat.models import DanswerContext +from danswer.context.search.models import InferenceSection def normalize_whitespace(text: str) -> str: @@ -17,8 +17,8 @@ def normalize_whitespace(text: str) -> str: # Post-processing -def format_docs(docs: Sequence[DanswerContext]) -> str: - return "\n\n".join(doc.content for doc in docs) +def format_docs(docs: Sequence[InferenceSection]) -> str: + return "\n\n".join(doc.combined_content for doc in docs) def clean_and_parse_list_string(json_string: str) -> list[dict]: diff --git a/backend/danswer/tools/message.py b/backend/danswer/tools/message.py index b0259c29b2a..6d261a8bf11 100644 --- a/backend/danswer/tools/message.py +++ b/backend/danswer/tools/message.py @@ -25,6 +25,9 @@ class ToolCallSummary(BaseModel__v1): tool_call_request: AIMessage tool_call_result: ToolMessage + class Config: + arbitrary_types_allowed = True + def tool_call_tokens( tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index eb2cdd04b5a..5054c641480 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -27,9 +27,7 @@ jira==3.5.1 jsonref==1.1.0 trafilatura==1.12.2 langchain==0.3.7 -langchain-community==0.3.7 langchain-core==0.3.20 -langchain-huggingface==0.1.2 langchain-openai==0.2.9 langchain-text-splitters==0.3.2 langchainhub==0.1.21 From f5e28e0f54b454bd768c8a6e8fb13004fd7bb2a1 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 9 Dec 2024 11:34:11 -0800 Subject: [PATCH 06/10] minor refactor --- .../agent_search/core_qa_graph/edges.py | 13 +- .../core_qa_graph/graph_builder.py | 123 +++++++++--------- .../nodes/combine_retrieved_docs.py | 2 +- .../core_qa_graph/nodes/custom_retrieve.py | 25 ++-- .../agent_search/core_qa_graph/nodes/dummy.py | 2 +- .../core_qa_graph/nodes/final_format.py | 2 +- .../core_qa_graph/nodes/generate.py | 2 +- .../core_qa_graph/nodes/qa_check.py | 2 +- .../core_qa_graph/nodes/rewrite.py | 2 +- .../core_qa_graph/nodes/verifier.py | 2 +- .../agent_search/core_qa_graph/states.py | 1 + .../agent_search/deep_qa_graph/edges.py | 11 +- .../deep_qa_graph/graph_builder.py | 77 ++++++----- .../nodes/combine_retrieved_docs.py | 2 +- .../deep_qa_graph/nodes/custom_retrieve.py | 2 +- .../agent_search/deep_qa_graph/nodes/dummy.py | 2 +- .../deep_qa_graph/nodes/final_format.py | 2 +- .../deep_qa_graph/nodes/generate.py | 2 +- .../deep_qa_graph/nodes/qa_check.py | 2 +- .../deep_qa_graph/nodes/rewrite.py | 2 +- .../deep_qa_graph/nodes/verifier.py | 2 +- backend/danswer/agent_search/primary_state.py | 16 +++ backend/danswer/agent_search/test.py | 120 +++++++++++++++++ .../util_sub_graphs/collect_docs.py | 44 +++++++ .../util_sub_graphs/dedupe_retrieved_docs.py | 65 +++++++++ 25 files changed, 385 insertions(+), 140 deletions(-) create mode 100644 backend/danswer/agent_search/primary_state.py create mode 100644 backend/danswer/agent_search/test.py create mode 100644 backend/danswer/agent_search/util_sub_graphs/collect_docs.py create mode 100644 backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py diff --git 
a/backend/danswer/agent_search/core_qa_graph/edges.py b/backend/danswer/agent_search/core_qa_graph/edges.py index 0d0c2b3a50d..3a175399c7c 100644 --- a/backend/danswer/agent_search/core_qa_graph/edges.py +++ b/backend/danswer/agent_search/core_qa_graph/edges.py @@ -1,5 +1,4 @@ -from collections.abc import Hashable -from typing import Union +from typing import Literal from langgraph.types import Send @@ -8,30 +7,30 @@ from danswer.agent_search.primary_graph.states import VerifierState -def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]: +def continue_to_verifier(state: BaseQAState) -> Literal["verifier"]: # Routes each de-douped retrieved doc to the verifier step - in parallel # Notice the 'Send()' API that takes care of the parallelization return [ Send( - "sub_verifier", + "verifier", VerifierState( document=doc, question=state["sub_question_str"], graph_start_time=state["graph_start_time"], ), ) - for doc in state["sub_question_base_retrieval_docs"] + for doc in state["sub_question_retrieval_docs"] ] -def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]: +def continue_to_retrieval(state: BaseQAState) -> Literal["custom_retrieve"]: # Routes re-written queries to the (parallel) retrieval steps # Notice the 'Send()' API that takes care of the parallelization rewritten_queries = state["sub_question_search_queries"].rewritten_queries return [ Send( - "sub_custom_retrieve", + "custom_retrieve", RetrieverState( rewritten_query=query, graph_start_time=state["graph_start_time"], diff --git a/backend/danswer/agent_search/core_qa_graph/graph_builder.py b/backend/danswer/agent_search/core_qa_graph/graph_builder.py index 1031d945cc5..985b8669f17 100644 --- a/backend/danswer/agent_search/core_qa_graph/graph_builder.py +++ b/backend/danswer/agent_search/core_qa_graph/graph_builder.py @@ -2,102 +2,109 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval -from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier -from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import ( - sub_combine_retrieved_docs, -) +from danswer.agent_search.core_qa_graph.edges import continue_to_retrieval +from danswer.agent_search.core_qa_graph.edges import continue_to_verifier from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import ( - sub_custom_retrieve, + custom_retrieve, ) -from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy +from danswer.agent_search.core_qa_graph.nodes.dummy import dummy from danswer.agent_search.core_qa_graph.nodes.final_format import ( - sub_final_format, + final_format, ) -from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate -from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check -from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite -from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier +from danswer.agent_search.core_qa_graph.nodes.generate import generate +from danswer.agent_search.core_qa_graph.nodes.qa_check import qa_check +from danswer.agent_search.core_qa_graph.nodes.rewrite import rewrite +from danswer.agent_search.core_qa_graph.nodes.verifier import verifier from danswer.agent_search.core_qa_graph.states import BaseQAOutputState from danswer.agent_search.core_qa_graph.states import BaseQAState from danswer.agent_search.core_qa_graph.states import CoreQAInputState +from 
danswer.agent_search.util_sub_graphs.collect_docs import collect_docs +from danswer.agent_search.util_sub_graphs.dedupe_retrieved_docs import ( + build_dedupe_retrieved_docs_graph, +) + +# from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import combine_retrieved_docs def build_core_qa_graph() -> StateGraph: - sub_answers_initial = StateGraph( + answers_initial = StateGraph( state_schema=BaseQAState, output=BaseQAOutputState, ) ### Add nodes ### - sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy) - sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite) - sub_answers_initial.add_node( - node="sub_custom_retrieve", - action=sub_custom_retrieve, + answers_initial.add_node(node="dummy", action=dummy) + answers_initial.add_node(node="rewrite", action=rewrite) + answers_initial.add_node( + node="custom_retrieve", + action=custom_retrieve, ) - sub_answers_initial.add_node( - node="sub_combine_retrieved_docs", - action=sub_combine_retrieved_docs, + # answers_initial.add_node( + # node="collect_docs", + # action=collect_docs, + # ) + build_dedupe_retrieved_docs_graph().compile() + answers_initial.add_node( + node="collect_docs", + action=collect_docs, ) - sub_answers_initial.add_node( - node="sub_verifier", - action=sub_verifier, + answers_initial.add_node( + node="verifier", + action=verifier, ) - sub_answers_initial.add_node( - node="sub_generate", - action=sub_generate, + answers_initial.add_node( + node="generate", + action=generate, ) - sub_answers_initial.add_node( - node="sub_qa_check", - action=sub_qa_check, + answers_initial.add_node( + node="qa_check", + action=qa_check, ) - sub_answers_initial.add_node( - node="sub_final_format", - action=sub_final_format, + answers_initial.add_node( + node="final_format", + action=final_format, ) ### Add edges ### - sub_answers_initial.add_edge(START, "sub_dummy") - sub_answers_initial.add_edge("sub_dummy", "sub_rewrite") + answers_initial.add_edge(START, "dummy") + answers_initial.add_edge("dummy", "rewrite") - sub_answers_initial.add_conditional_edges( - source="sub_rewrite", - path=sub_continue_to_retrieval, + answers_initial.add_conditional_edges( + source="rewrite", + path=continue_to_retrieval, ) - sub_answers_initial.add_edge( - start_key="sub_custom_retrieve", - end_key="sub_combine_retrieved_docs", + answers_initial.add_edge( + start_key="custom_retrieve", + end_key="collect_docs", ) - sub_answers_initial.add_conditional_edges( - source="sub_combine_retrieved_docs", - path=sub_continue_to_verifier, - path_map=["sub_verifier"], + answers_initial.add_conditional_edges( + source="collect_docs", + path=continue_to_verifier, ) - sub_answers_initial.add_edge( - start_key="sub_verifier", - end_key="sub_generate", + answers_initial.add_edge( + start_key="verifier", + end_key="generate", ) - sub_answers_initial.add_edge( - start_key="sub_generate", - end_key="sub_qa_check", + answers_initial.add_edge( + start_key="generate", + end_key="qa_check", ) - sub_answers_initial.add_edge( - start_key="sub_qa_check", - end_key="sub_final_format", + answers_initial.add_edge( + start_key="qa_check", + end_key="final_format", ) - sub_answers_initial.add_edge( - start_key="sub_final_format", + answers_initial.add_edge( + start_key="final_format", end_key=END, ) - # sub_answers_graph = sub_answers_initial.compile() - return sub_answers_initial + # answers_graph = answers_initial.compile() + return answers_initial if __name__ == "__main__": diff --git 
a/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py index 64f657363ef..bb45f083328 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py @@ -6,7 +6,7 @@ from danswer.context.search.models import InferenceSection -def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]: +def combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]: """ Dedupe the retrieved docs. """ diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py index 4793c511b7c..1fe4a840c99 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py @@ -1,8 +1,10 @@ import datetime -from typing import Any from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.agent_search.util_sub_graphs.collect_docs import CollectDocsInput +from danswer.agent_search.util_sub_graphs.dedupe_retrieved_docs import ( + DedupeRetrievedDocsInput, +) from danswer.context.search.models import InferenceSection from danswer.context.search.models import SearchRequest from danswer.context.search.pipeline import SearchPipeline @@ -10,7 +12,7 @@ from danswer.llm.factory import get_default_llms -def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: +def custom_retrieve(state: RetrieverState) -> DedupeRetrievedDocsInput: """ Retrieve documents @@ -22,7 +24,7 @@ def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: """ print("---RETRIEVE SUB---") - node_start_time = datetime.datetime.now() + datetime.datetime.now() rewritten_query = state["rewritten_query"] @@ -41,11 +43,10 @@ def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: db_session=db_session, ).reranked_sections - return { - "sub_question_base_retrieval_docs": documents, - "log_messages": generate_log_message( - message="sub - custom_retrieve", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } + return CollectDocsInput( + sub_question_retrieval_docs=documents, + ) + + # return DedupeRetrievedDocsInput( + # pre_dedupe_docs=documents, + # ) diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py b/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py index be018334aa9..e8718838300 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py @@ -5,7 +5,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_dummy(state: BaseQAState) -> dict[str, Any]: +def dummy(state: BaseQAState) -> dict[str, Any]: """ Dummy step """ diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py b/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py index 8b24cc0a6ed..c75262d08ad 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py @@ -3,7 +3,7 @@ from danswer.agent_search.core_qa_graph.states import BaseQAState -def sub_final_format(state: BaseQAState) -> dict[str, Any]: +def final_format(state: BaseQAState) -> dict[str, Any]: """ Create the final output for the QA subgraph 
""" diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/generate.py b/backend/danswer/agent_search/core_qa_graph/nodes/generate.py index ed1c3661be9..21416d7473e 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/generate.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/generate.py @@ -11,7 +11,7 @@ from danswer.llm.factory import get_default_llms -def sub_generate(state: BaseQAState) -> dict[str, Any]: +def generate(state: BaseQAState) -> dict[str, Any]: """ Generate answer diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py b/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py index adff7a3d4cd..3226bac00c0 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py @@ -10,7 +10,7 @@ from danswer.llm.factory import get_default_llms -def sub_qa_check(state: BaseQAState) -> dict[str, Any]: +def qa_check(state: BaseQAState) -> dict[str, Any]: """ Check if the sub-question answer is satisfactory. diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py b/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py index e5841efbd46..57118227b9e 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py @@ -11,7 +11,7 @@ from danswer.llm.factory import get_default_llms -def sub_rewrite(state: BaseQAState) -> dict[str, Any]: +def rewrite(state: BaseQAState) -> dict[str, Any]: """ Transform the initial question into more suitable search queries. diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py b/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py index 8bcaee1a5b7..e5b89a2761d 100644 --- a/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py +++ b/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py @@ -11,7 +11,7 @@ from danswer.llm.factory import get_default_llms -def sub_verifier(state: VerifierState) -> dict[str, Any]: +def verifier(state: VerifierState) -> dict[str, Any]: """ Check whether the document is relevant for the original user question diff --git a/backend/danswer/agent_search/core_qa_graph/states.py b/backend/danswer/agent_search/core_qa_graph/states.py index 5c53df06275..45cec7e58cf 100644 --- a/backend/danswer/agent_search/core_qa_graph/states.py +++ b/backend/danswer/agent_search/core_qa_graph/states.py @@ -37,6 +37,7 @@ class BaseQAState(TypedDict): sub_question_rewritten_queries: list[str] sub_question_str: str sub_question_search_queries: RewrittenQueries + deduped_retrieval_docs: list[InferenceSection] sub_question_nr: int sub_question_base_retrieval_docs: Annotated[ Sequence[InferenceSection], operator.add diff --git a/backend/danswer/agent_search/deep_qa_graph/edges.py b/backend/danswer/agent_search/deep_qa_graph/edges.py index 980af9159b9..6f0f143aa03 100644 --- a/backend/danswer/agent_search/deep_qa_graph/edges.py +++ b/backend/danswer/agent_search/deep_qa_graph/edges.py @@ -1,5 +1,4 @@ from collections.abc import Hashable -from typing import Union from langgraph.types import Send @@ -8,7 +7,7 @@ from danswer.agent_search.primary_graph.states import VerifierState -def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]: +def continue_to_verifier(state: ResearchQAState) -> list[Hashable]: # Routes each de-douped retrieved doc to the verifier step - in parallel # Notice the 'Send()' API that takes care of the parallelization @@ -18,8 +17,6 @@ def 
sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Has VerifierState( document=doc, question=state["sub_question"], - primary_llm=state["primary_llm"], - fast_llm=state["fast_llm"], graph_start_time=state["graph_start_time"], ), ) @@ -27,9 +24,9 @@ def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Has ] -def sub_continue_to_retrieval( +def continue_to_retrieval( state: ResearchQAState, -) -> Union[Hashable, list[Hashable]]: +) -> list[Hashable]: # Routes re-written queries to the (parallel) retrieval steps # Notice the 'Send()' API that takes care of the parallelization return [ @@ -37,8 +34,6 @@ def sub_continue_to_retrieval( "sub_custom_retrieve", RetrieverState( rewritten_query=query, - primary_llm=state["primary_llm"], - fast_llm=state["fast_llm"], graph_start_time=state["graph_start_time"], ), ) diff --git a/backend/danswer/agent_search/deep_qa_graph/graph_builder.py b/backend/danswer/agent_search/deep_qa_graph/graph_builder.py index 90c7aebeb0f..0e0cae88127 100644 --- a/backend/danswer/agent_search/deep_qa_graph/graph_builder.py +++ b/backend/danswer/agent_search/deep_qa_graph/graph_builder.py @@ -2,92 +2,89 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval -from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier +from danswer.agent_search.core_qa_graph.nodes.rewrite import rewrite +from danswer.agent_search.deep_qa_graph.edges import continue_to_retrieval +from danswer.agent_search.deep_qa_graph.edges import continue_to_verifier from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import ( - sub_combine_retrieved_docs, + combine_retrieved_docs, ) -from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve -from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy -from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format -from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate -from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check -from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier +from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import custom_retrieve +from danswer.agent_search.deep_qa_graph.nodes.final_format import final_format +from danswer.agent_search.deep_qa_graph.nodes.generate import generate +from danswer.agent_search.deep_qa_graph.nodes.qa_check import qa_check +from danswer.agent_search.deep_qa_graph.nodes.verifier import verifier from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState from danswer.agent_search.deep_qa_graph.states import ResearchQAState def build_deep_qa_graph() -> StateGraph: # Define the nodes we will cycle between - sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState) + answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState) ### Add Nodes ### # Dummy node for initial processing - sub_answers.add_node(node="sub_dummy", action=sub_dummy) + # answers.add_node(node="dummy", action=dummy) + answers.add_node(node="rewrite", action=rewrite) # The retrieval step - sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve) + answers.add_node(node="custom_retrieve", action=custom_retrieve) # The dedupe step - sub_answers.add_node( - node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs - ) + 
answers.add_node(node="combine_retrieved_docs", action=combine_retrieved_docs) # Verifying retrieved information - sub_answers.add_node(node="sub_verifier", action=sub_verifier) + answers.add_node(node="verifier", action=verifier) # Generating the response - sub_answers.add_node(node="sub_generate", action=sub_generate) + answers.add_node(node="generate", action=generate) # Checking the quality of the answer - sub_answers.add_node(node="sub_qa_check", action=sub_qa_check) + answers.add_node(node="qa_check", action=qa_check) # Final formatting of the response - sub_answers.add_node(node="sub_final_format", action=sub_final_format) + answers.add_node(node="final_format", action=final_format) ### Add Edges ### # Generate multiple sub-questions - sub_answers.add_edge(start_key=START, end_key="sub_rewrite") + answers.add_edge(start_key=START, end_key="rewrite") # For each sub-question, perform a retrieval in parallel - sub_answers.add_conditional_edges( - source="sub_rewrite", - path=sub_continue_to_retrieval, - path_map=["sub_custom_retrieve"], + answers.add_conditional_edges( + source="rewrite", + path=continue_to_retrieval, + path_map=["custom_retrieve"], ) # Combine the retrieved docs for each sub-question from the parallel retrievals - sub_answers.add_edge( - start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs" - ) + answers.add_edge(start_key="custom_retrieve", end_key="combine_retrieved_docs") # Go over all of the combined retrieved docs and verify them against the original question - sub_answers.add_conditional_edges( - source="sub_combine_retrieved_docs", - path=sub_continue_to_verifier, - path_map=["sub_verifier"], + answers.add_conditional_edges( + source="combine_retrieved_docs", + path=continue_to_verifier, + path_map=["verifier"], ) # Generate an answer for each verified retrieved doc - sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate") + answers.add_edge(start_key="verifier", end_key="generate") # Check the quality of the answer - sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check") + answers.add_edge(start_key="generate", end_key="qa_check") - sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format") + answers.add_edge(start_key="qa_check", end_key="final_format") - sub_answers.add_edge(start_key="sub_final_format", end_key=END) + answers.add_edge(start_key="final_format", end_key=END) - return sub_answers + return answers if __name__ == "__main__": # TODO: add the actual question - inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"} - sub_answers_graph = build_deep_qa_graph() - compiled_sub_answers = sub_answers_graph.compile() - output = compiled_sub_answers.invoke(inputs) + inputs = {"question": "Whose music is kind of hard to easily enjoy?"} + answers_graph = build_deep_qa_graph() + compiled_answers = answers_graph.compile() + output = compiled_answers.invoke(inputs) print("\nOUTPUT:") print(output) diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py index 542c823ae33..f2a5be3754a 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py @@ -5,7 +5,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]: +def combine_retrieved_docs(state: ResearchQAState) -> 
dict[str, Any]: """ Dedupe the retrieved docs. """ diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py index b041de35768..3e00a94a23d 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py @@ -6,7 +6,7 @@ from danswer.context.search.models import InferenceSection -def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]: +def custom_retrieve(state: RetrieverState) -> dict[str, Any]: """ Retrieve documents diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py b/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py index 976c81744de..e5ada974d23 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py @@ -5,7 +5,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_dummy(state: BaseQAState) -> dict[str, Any]: +def dummy(state: BaseQAState) -> dict[str, Any]: """ Dummy step """ diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py b/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py index d3a8706f9e1..6515c2b4e9e 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py @@ -5,7 +5,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_final_format(state: ResearchQAState) -> dict[str, Any]: +def final_format(state: ResearchQAState) -> dict[str, Any]: """ Create the final output for the QA subgraph """ diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py b/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py index dbd478fb796..48faed35adc 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py @@ -10,7 +10,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_generate(state: ResearchQAState) -> dict[str, Any]: +def generate(state: ResearchQAState) -> dict[str, Any]: """ Generate answer diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py b/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py index 5dd77af213c..40fe7c70427 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py @@ -10,7 +10,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_qa_check(state: ResearchQAState) -> dict[str, Any]: +def qa_check(state: ResearchQAState) -> dict[str, Any]: """ Check whether the final output satisfies the original user question diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py b/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py index f7cb836842a..f0ac9fbdabd 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py @@ -11,7 +11,7 @@ from danswer.llm.interfaces import LLM -def sub_rewrite(state: ResearchQAState) -> dict[str, Any]: +def rewrite(state: ResearchQAState) -> dict[str, Any]: """ Transform the initial question into more suitable search queries. 
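
All of the continue_to_* edges in these graphs follow the same rewrite-then-fan-out pattern: a conditional edge returns one Send per rewritten query, and a reducer on the shared state merges the parallel results back together. A minimal, self-contained sketch of that pattern, with hypothetical node and field names standing in for the LLM-backed rewrite, looks like this:

import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    question: str
    queries: list[str]
    # fan-in: each parallel branch's docs are concatenated by the reducer
    docs: Annotated[list[str], operator.add]


def rewrite(state: State) -> dict:
    # Stand-in for the LLM call that expands one question into N queries.
    return {"queries": [state["question"], state["question"] + " (rephrased)"]}


def retrieve(state: dict) -> dict:
    # Each Send carries its own private payload; this node runs once per
    # query, in parallel, and the reducer collects the results.
    return {"docs": [f"doc for: {state['query']}"]}


def fan_out(state: State) -> list[Send]:
    return [Send("retrieve", {"query": q}) for q in state["queries"]]


builder = StateGraph(State)
builder.add_node("rewrite", rewrite)
builder.add_node("retrieve", retrieve)
builder.add_edge(START, "rewrite")
builder.add_conditional_edges("rewrite", fan_out, ["retrieve"])
builder.add_edge("retrieve", END)

# graph = builder.compile()
# graph.invoke({"question": "who made Excel?"})  -> state with two docs
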
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py b/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py index 49bfcb62d17..c24f5a36c03 100644 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py +++ b/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py @@ -10,7 +10,7 @@ from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def sub_verifier(state: VerifierState) -> dict[str, Any]: +def verifier(state: VerifierState) -> dict[str, Any]: """ Check whether the document is relevant for the original user question diff --git a/backend/danswer/agent_search/primary_state.py b/backend/danswer/agent_search/primary_state.py new file mode 100644 index 00000000000..4f5fd7de6cd --- /dev/null +++ b/backend/danswer/agent_search/primary_state.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import TypedDict + +from sqlalchemy.orm import Session + +from danswer.llm.interfaces import LLM + + +class PrimaryState(TypedDict): + agent_search_start_time: datetime + original_question: str + primary_llm: LLM + fast_llm: LLM + # a single session for the entire agent search + # is fine if we are only reading + db_session: Session diff --git a/backend/danswer/agent_search/test.py b/backend/danswer/agent_search/test.py new file mode 100644 index 00000000000..1f760e4dc90 --- /dev/null +++ b/backend/danswer/agent_search/test.py @@ -0,0 +1,120 @@ +from typing import Annotated +from typing import Literal +from typing import TypedDict + +from dotenv import load_dotenv +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph +from langgraph.types import Send + + +def unique_concat(a: list[str], b: list[str]) -> list[str]: + combined = a + b + return list(set(combined)) + + +load_dotenv(".vscode/.env") + + +class InputState(TypedDict): + user_input: str + # str_arr: list[str] + + +class OutputState(TypedDict): + graph_output: str + + +class SharedState(TypedDict): + llm: int + + +class OverallState(TypedDict): + foo: str + user_input: str + str_arr: Annotated[list[str], unique_concat] + # str_arr: list[str] + + +class PrivateState(TypedDict): + foo: str + bar: str + + +def conditional_edge_3(state: PrivateState) -> Literal["node_4"]: + print(f"conditional_edge_3: {state}") + return Send( + "node_4", + state, + ) + + +def node_1(state: OverallState): + print(f"node_1: {state}") + return { + "foo": state["user_input"] + " name", + "user_input": state["user_input"], + "str_arr": ["a", "b", "c"], + } + + +def node_2(state: OverallState): + print(f"node_2: {state}") + return { + "foo": "foo", + "bar": "bar", + "test1": "test1", + "str_arr": ["a", "d", "e", "f"], + } + + +def node_3(state: PrivateState): + print(f"node_3: {state}") + return {"bar": state["bar"] + " Lance"} + + +def node_4(state: PrivateState): + print(f"node_4: {state}") + return { + "foo": state["bar"] + " more bar", + } + + +def node_5(state: OverallState): + print(f"node_5: {state}") + updated_aggregate = [item for item in state["str_arr"] if "b" not in item] + print(f"updated_aggregate: {updated_aggregate}") + return {"str_arr": updated_aggregate} + + +builder = StateGraph( + state_schema=OverallState, + # input=InputState, + # output=OutputState +) +builder.add_node("node_1", node_1) +builder.add_node("node_2", node_2) +builder.add_node("node_3", node_3) +builder.add_node("node_4", node_4) +builder.add_node("node_5", node_5) +builder.add_edge(START, "node_1") +builder.add_edge("node_1", "node_2") 
+builder.add_edge("node_2", "node_3") +builder.add_conditional_edges( + source="node_3", + path=conditional_edge_3, +) +builder.add_edge("node_4", "node_5") +builder.add_edge("node_5", END) +graph = builder.compile() +# output = graph.invoke( +# {"user_input":"My"}, +# stream_mode="values", +# ) +for chunk in graph.stream( + {"user_input": "My"}, + stream_mode="debug", +): + print() + print(chunk) diff --git a/backend/danswer/agent_search/util_sub_graphs/collect_docs.py b/backend/danswer/agent_search/util_sub_graphs/collect_docs.py new file mode 100644 index 00000000000..7666411ac95 --- /dev/null +++ b/backend/danswer/agent_search/util_sub_graphs/collect_docs.py @@ -0,0 +1,44 @@ +import operator +from datetime import datetime +from typing import Annotated + +from danswer.agent_search.primary_state import PrimaryState +from danswer.context.search.models import InferenceSection + + +class CollectDocsInput(PrimaryState): + sub_question_retrieval_docs: Annotated[list[InferenceSection], operator.add] + sub_question_str: str + graph_start_time: datetime + + +class CollectDocsOutput(PrimaryState): + deduped_retrieval_docs: list[InferenceSection] + + +def collect_docs(state: CollectDocsInput) -> CollectDocsOutput: + """ + Dedupe the retrieved docs. + """ + + sub_question_retrieval_docs = state["sub_question_retrieval_docs"] + + print(f"Number of docs from steps: {len(sub_question_retrieval_docs)}") + dedupe_docs: list[InferenceSection] = [] + for retrieval_doc in sub_question_retrieval_docs: + if not any( + retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id + for doc in dedupe_docs + ): + dedupe_docs.append(retrieval_doc) + + print(f"Number of deduped docs: {len(dedupe_docs)}") + + return state + # return CollectDocsOutput( + # deduped_retrieval_docs=dedupe_docs, + # ) + return { + "deduped_retrieval_docs": dedupe_docs, + "test_var": "test_var", + } diff --git a/backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py b/backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py new file mode 100644 index 00000000000..b03348d6c30 --- /dev/null +++ b/backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py @@ -0,0 +1,65 @@ +from datetime import datetime +from typing import TypedDict + +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from danswer.context.search.models import InferenceSection + + +class DedupeRetrievedDocsInput(TypedDict): + pre_dedupe_docs: list[InferenceSection] + + +class DedupeRetrievedDocsOutput(TypedDict): + deduped_docs: list[InferenceSection] + + +def dedupe_retrieved_docs(state: DedupeRetrievedDocsInput) -> DedupeRetrievedDocsOutput: + """ + Dedupe the retrieved docs. 
+    """
+    datetime.now()
+
+    pre_dedupe_docs = state["pre_dedupe_docs"]
+
+    print(f"Number of docs from steps: {len(pre_dedupe_docs)}")
+    dedupe_docs: list[InferenceSection] = []
+    for pre_dedupe_doc in pre_dedupe_docs:
+        if not any(
+            pre_dedupe_doc.center_chunk.document_id == doc.center_chunk.document_id
+            for doc in dedupe_docs
+        ):
+            dedupe_docs.append(pre_dedupe_doc)
+
+    print(f"Number of deduped docs: {len(dedupe_docs)}")
+
+    return DedupeRetrievedDocsOutput(
+        deduped_docs=dedupe_docs,
+    )
+
+
+def build_dedupe_retrieved_docs_graph() -> StateGraph:
+    dedupe_retrieved_docs_graph = StateGraph(
+        state_schema=DedupeRetrievedDocsInput,
+        input=DedupeRetrievedDocsInput,
+        output=DedupeRetrievedDocsOutput,
+    )
+
+    dedupe_retrieved_docs_graph.add_node(
+        node="dedupe_retrieved_docs",
+        action=dedupe_retrieved_docs,
+    )
+
+    dedupe_retrieved_docs_graph.add_edge(
+        start_key=START,
+        end_key="dedupe_retrieved_docs",
+    )
+
+    dedupe_retrieved_docs_graph.add_edge(
+        start_key="dedupe_retrieved_docs",
+        end_key=END,
+    )
+
+    return dedupe_retrieved_docs_graph
From ad4df04159d537a10446a2caf8aa876f3a276ab3 Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Thu, 12 Dec 2024 15:20:39 -0800
Subject: [PATCH 07/10] done with expanded retrieval

---
 .../answer_query/graph_builder.py             |   0
 .../answer_query/nodes/generate.py            |   0
 .../answer_query/nodes/qa_check.py            |   0
 .../agent_search/answer_query/states.py       |   0
 .../agent_search/expanded_retrieval/edges.py  |  59 ++++++++++
 .../expanded_retrieval/graph_builder.py       |  95 ++++++++++++++++
 .../nodes/collect_retrieved_docs.py           |  27 +++++
 .../expanded_retrieval/nodes/doc_reranking.py |  11 ++
 .../expanded_retrieval/nodes/doc_retrieval.py | 104 ++++++++++++++++++
 .../nodes/doc_verification.py                 |  60 ++++++++++
 .../expanded_retrieval/prompts.py             |   0
 .../agent_search/expanded_retrieval/states.py |  43 ++++++++
 backend/danswer/agent_search/primary_state.py |  11 +-
 backend/danswer/agent_search/run_graph.py     |   7 +-
 backend/requirements/default.txt              |   2 +-
 15 files changed, 413 insertions(+), 6 deletions(-)
 create mode 100644 backend/danswer/agent_search/answer_query/graph_builder.py
 create mode 100644 backend/danswer/agent_search/answer_query/nodes/generate.py
 create mode 100644 backend/danswer/agent_search/answer_query/nodes/qa_check.py
 create mode 100644 backend/danswer/agent_search/answer_query/states.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/edges.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/graph_builder.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/nodes/doc_reranking.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/nodes/doc_verification.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/prompts.py
 create mode 100644 backend/danswer/agent_search/expanded_retrieval/states.py
diff --git a/backend/danswer/agent_search/answer_query/graph_builder.py b/backend/danswer/agent_search/answer_query/graph_builder.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/danswer/agent_search/answer_query/nodes/generate.py b/backend/danswer/agent_search/answer_query/nodes/generate.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/danswer/agent_search/answer_query/nodes/qa_check.py b/backend/danswer/agent_search/answer_query/nodes/qa_check.py
new file 
mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/agent_search/answer_query/states.py b/backend/danswer/agent_search/answer_query/states.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/agent_search/expanded_retrieval/edges.py b/backend/danswer/agent_search/expanded_retrieval/edges.py new file mode 100644 index 00000000000..2524ec66ea0 --- /dev/null +++ b/backend/danswer/agent_search/expanded_retrieval/edges.py @@ -0,0 +1,59 @@ +from typing import Literal + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs +from langgraph.types import Send + +from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from danswer.llm.interfaces import LLM + + +def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> Literal["doc_retrieval"]: + """ + Transform the initial question into more suitable search queries. + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + + print(f"parallel_retrieval_edge state: {state.keys()}") + # messages = state["base_answer_messages"] + question = state["query_to_expand"] + llm: LLM = state["fast_llm"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI.format(question=question), + ) + ] + llm_response_list = list( + llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + print(f"llm_response: {llm_response}") + + rewritten_queries = llm_response.split("\n") + + print(f"rewritten_queries: {rewritten_queries}") + + return [ + Send( + "doc_retrieval", + RetrieveInput(query_to_retrieve=query, **state), + ) + for query in rewritten_queries + ] + + +def conditionally_rerank_edge(state: ExpandedRetrievalState) -> bool: + print(f"conditionally_rerank_edge state: {state.keys()}") + return bool(state["search_request"].rerank_settings) diff --git a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py new file mode 100644 index 00000000000..b17e736b6e0 --- /dev/null +++ b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py @@ -0,0 +1,95 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from danswer.agent_search.expanded_retrieval.edges import conditionally_rerank_edge +from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge +from danswer.agent_search.expanded_retrieval.nodes.collect_retrieved_docs import ( + kick_off_verification, +) +from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking +from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval +from danswer.agent_search.expanded_retrieval.nodes.doc_verification import ( + doc_verification, +) +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def qa_graph_builder() -> StateGraph: + graph = StateGraph(ExpandedRetrievalState) + + graph.add_node( + node="doc_retrieval", + action=doc_retrieval, + ) + graph.add_node( + node="kick_off_verification", + 
action=kick_off_verification, + ) + graph.add_node( + node="doc_verification", + action=doc_verification, + ) + graph.add_node( + node="doc_reranking", + action=doc_reranking, + ) + + graph.add_conditional_edges( + source=START, + path=parallel_retrieval_edge, + ) + graph.add_edge( + start_key="doc_retrieval", + end_key="kick_off_verification", + ) + graph.add_conditional_edges( + source="doc_verification", + path=conditionally_rerank_edge, + path_map={ + True: "doc_reranking", + False: END, + }, + ) + graph.add_edge( + start_key="doc_reranking", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from danswer.db.engine import get_session_context_manager + from danswer.llm.factory import get_default_llms + from danswer.context.search.models import SearchRequest + + graph = qa_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="Who made Excel and what other products did they make?", + ) + with get_session_context_manager() as db_session: + inputs = ExpandedRetrievalInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + query_to_expand="Who made Excel?", + ) + for thing in compiled_graph.stream(inputs, debug=True): + print(thing) + # output = compiled_graph.invoke(inputs) + # print("\nOUTPUT:") + # print(output.keys()) + # for key, value in output.items(): + # if key in [ + # "sub_question_answer", + # "sub_question_str", + # "sub_qas", + # "initial_sub_qas", + # "sub_question_answer", + # ]: + # print(f"{key}: {value}") diff --git a/backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py b/backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py new file mode 100644 index 00000000000..4bd882a5436 --- /dev/null +++ b/backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py @@ -0,0 +1,27 @@ +from typing import Literal + +from langgraph.types import Command +from langgraph.types import Send + +from danswer.agent_search.expanded_retrieval.nodes.doc_verification import ( + DocVerificationInput, +) +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def kick_off_verification( + state: ExpandedRetrievalState, +) -> Command[Literal["doc_verification"]]: + print(f"kick_off_verification state: {state.keys()}") + + documents = state["retrieved_documents"] + return Command( + update={}, + goto=[ + Send( + node="doc_verification", + arg=DocVerificationInput(doc_to_verify=doc, **state), + ) + for doc in documents + ], + ) diff --git a/backend/danswer/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_reranking.py new file mode 100644 index 00000000000..a92c6a59093 --- /dev/null +++ b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -0,0 +1,11 @@ +from danswer.agent_search.expanded_retrieval.states import DocRerankingOutput +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput: + print(f"doc_reranking state: {state.keys()}") + + verified_documents = state["verified_documents"] + reranked_documents = verified_documents + + return DocRerankingOutput(reranked_documents=reranked_documents) diff --git a/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py new file 
mode 100644 index 00000000000..05c92d0c2c8 --- /dev/null +++ b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -0,0 +1,104 @@ +import random +import time +from datetime import datetime +from unittest.mock import MagicMock + +from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput +from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from danswer.configs.constants import DocumentSource +from danswer.context.search.models import InferenceChunk +from danswer.context.search.models import InferenceSection + + +def create_mock_inference_section() -> MagicMock: + # Create a mock InferenceChunk first + mock_chunk = MagicMock(spec=InferenceChunk) + mock_chunk.document_id = f"test_doc_id_{random.randint(1, 1000)}" + mock_chunk.source_type = DocumentSource.FILE + mock_chunk.semantic_identifier = "test_semantic_id" + mock_chunk.title = "Test Title" + mock_chunk.boost = 1 + mock_chunk.recency_bias = 1.0 + mock_chunk.score = 0.95 + mock_chunk.hidden = False + mock_chunk.is_relevant = True + mock_chunk.relevance_explanation = "Test relevance" + mock_chunk.metadata = {"key": "value"} + mock_chunk.match_highlights = ["test highlight"] + mock_chunk.updated_at = datetime.now() + mock_chunk.primary_owners = ["owner1"] + mock_chunk.secondary_owners = ["owner2"] + mock_chunk.large_chunk_reference_ids = [1, 2] + mock_chunk.chunk_id = 1 + mock_chunk.content = "Test content" + mock_chunk.blurb = "Test blurb" + + # Create the InferenceSection mock + mock_section = MagicMock(spec=InferenceSection) + mock_section.center_chunk = mock_chunk + mock_section.chunks = [mock_chunk] + mock_section.combined_content = "Test combined content" + + return mock_section + + +def get_mock_inference_sections() -> list[InferenceSection]: + """Returns a list of mock InferenceSections for testing""" + return [create_mock_inference_section()] + + +class RetrieveInput(ExpandedRetrievalState): + query_to_retrieve: str + + +def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: + # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: + """ + Retrieve documents + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print(f"doc_retrieval state: {state.keys()}") + + state["query_to_retrieve"] + + documents: list[InferenceSection] = [] + state["primary_llm"] + state["fast_llm"] + # db_session = state["db_session"] + + # from danswer.db.engine import get_session_context_manager + # with get_session_context_manager() as db_session1: + # documents = SearchPipeline( + # search_request=SearchRequest( + # query=query_to_retrieve, + # ), + # user=None, + # llm=llm, + # fast_llm=fast_llm, + # db_session=db_session1, + # ).reranked_sections + + time.sleep(random.random() * 10) + + documents = get_mock_inference_sections() + + print(f"documents: {documents}") + + # return Command( + # update={"retrieved_documents": documents}, + # goto=Send( + # node="doc_verification", + # arg=DocVerificationInput( + # doc_to_verify=documents, + # **state + # ), + # ), + # ) + return DocRetrievalOutput( + retrieved_documents=documents, + ) diff --git a/backend/danswer/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_verification.py new file mode 100644 index 00000000000..ab2907efd13 --- /dev/null +++ b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -0,0 +1,60 @@ 
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import merge_message_runs
+
+from danswer.agent_search.expanded_retrieval.states import DocVerificationOutput
+from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
+from danswer.agent_search.shared_graph_utils.models import BinaryDecision
+from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from danswer.context.search.models import InferenceSection
+
+
+class DocVerificationInput(ExpandedRetrievalState, total=True):
+    doc_to_verify: InferenceSection
+
+
+def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
+    """
+    Check whether the document is relevant for the original user question
+
+    Args:
+        state (DocVerificationInput): The current state
+
+    Returns:
+        dict: The updated state with the final decision
+    """
+
+    print(f"doc_verification state: {state.keys()}")
+
+    original_query = state["search_request"].query
+    doc_to_verify = state["doc_to_verify"]
+    document_content = doc_to_verify.combined_content
+
+    msg = [
+        HumanMessage(
+            content=VERIFIER_PROMPT.format(
+                question=original_query, document_content=document_content
+            )
+        )
+    ]
+
+    fast_llm = state["fast_llm"]
+    response = list(
+        fast_llm.stream(
+            prompt=msg,
+        )
+    )
+
+    response_string = merge_message_runs(response, chunk_separator="")[0].content
+    # Convert string response to proper dictionary format
+    decision_dict = {"decision": response_string.lower()}
+    formatted_response = BinaryDecision.model_validate(decision_dict)
+
+    print(f"Verdict: {formatted_response.decision}")
+
+    verified_documents = []
+    if formatted_response.decision == "yes":
+        verified_documents.append(doc_to_verify)
+
+    return DocVerificationOutput(
+        verified_documents=verified_documents,
+    )
diff --git a/backend/danswer/agent_search/expanded_retrieval/prompts.py b/backend/danswer/agent_search/expanded_retrieval/prompts.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/danswer/agent_search/expanded_retrieval/states.py b/backend/danswer/agent_search/expanded_retrieval/states.py
new file mode 100644
index 00000000000..a51985964b8
--- /dev/null
+++ b/backend/danswer/agent_search/expanded_retrieval/states.py
@@ -0,0 +1,43 @@
+from typing import Annotated
+from typing import TypedDict
+
+from danswer.agent_search.primary_state import PrimaryState
+from danswer.context.search.models import InferenceSection
+from danswer.llm.answering.prune_and_merge import _merge_sections
+
+
+def dedup_inference_sections(
+    list1: list[InferenceSection], list2: list[InferenceSection]
+) -> list[InferenceSection]:
+    deduped = _merge_sections(list1 + list2)
+    return deduped
+
+
+class DocRetrievalOutput(TypedDict, total=False):
+    retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+
+
+class DocVerificationOutput(TypedDict, total=False):
+    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+
+
+class DocRerankingOutput(TypedDict, total=False):
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+
+
+class ExpandedRetrievalState(
+    PrimaryState,
+    DocRetrievalOutput,
+    DocVerificationOutput,
+    DocRerankingOutput,
+    total=True,
+):
+    query_to_expand: str
+
+
+class ExpandedRetrievalInput(PrimaryState, total=True):
+    query_to_expand: str
+
+
+class ExpandedRetrievalOutput(TypedDict):
+    documents: Annotated[list[InferenceSection], dedup_inference_sections]
diff --git a/backend/danswer/agent_search/primary_state.py 
b/backend/danswer/agent_search/primary_state.py index 4f5fd7de6cd..23898e774cb 100644 --- a/backend/danswer/agent_search/primary_state.py +++ b/backend/danswer/agent_search/primary_state.py @@ -1,16 +1,19 @@ -from datetime import datetime from typing import TypedDict from sqlalchemy.orm import Session +from danswer.context.search.models import SearchRequest from danswer.llm.interfaces import LLM -class PrimaryState(TypedDict): - agent_search_start_time: datetime - original_question: str +class PrimaryState(TypedDict, total=False): + search_request: SearchRequest primary_llm: LLM fast_llm: LLM # a single session for the entire agent search # is fine if we are only reading db_session: Session + + +class InputState(PrimaryState, total=True): + pass diff --git a/backend/danswer/agent_search/run_graph.py b/backend/danswer/agent_search/run_graph.py index 02c14a64438..6cdd0653778 100644 --- a/backend/danswer/agent_search/run_graph.py +++ b/backend/danswer/agent_search/run_graph.py @@ -12,7 +12,7 @@ def run_graph( graph = build_core_graph() inputs = { - "original_question": query, + "original_query": query, "messages": [], "tools": tools, "llm": llm, @@ -20,3 +20,8 @@ def run_graph( compiled_graph = graph.compile() output = compiled_graph.invoke(input=inputs) yield from output + + +if __name__ == "__main__": + pass + # run_graph("What is the capital of France?", llm, []) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 5054c641480..8a61115147b 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -31,7 +31,7 @@ langchain-core==0.3.20 langchain-openai==0.2.9 langchain-text-splitters==0.3.2 langchainhub==0.1.21 -langgraph==0.2.53 +langgraph==0.2.59 langgraph-checkpoint==2.0.5 langgraph-sdk==0.1.36 litellm==1.53.1 From f4e8ac1dde32df6f2b38eb70bd1f642a58d1b3c9 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 12 Dec 2024 15:54:13 -0800 Subject: [PATCH 08/10] 2 graphs down --- .../answer_query/graph_builder.py | 72 +++++++++++++++++++ .../answer_query/nodes/generate.py | 0 .../answer_query/nodes/qa_check.py | 31 ++++++++ .../answer_query/nodes/qa_generation.py | 32 +++++++++ .../agent_search/answer_query/states.py | 43 +++++++++++ .../agent_search/expanded_retrieval/edges.py | 2 +- .../expanded_retrieval/graph_builder.py | 16 +---- .../agent_search/expanded_retrieval/states.py | 4 +- .../util_sub_graphs/collect_docs.py | 44 ------------ .../util_sub_graphs/dedupe_retrieved_docs.py | 65 ----------------- 10 files changed, 183 insertions(+), 126 deletions(-) delete mode 100644 backend/danswer/agent_search/answer_query/nodes/generate.py create mode 100644 backend/danswer/agent_search/answer_query/nodes/qa_generation.py delete mode 100644 backend/danswer/agent_search/util_sub_graphs/collect_docs.py delete mode 100644 backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py diff --git a/backend/danswer/agent_search/answer_query/graph_builder.py b/backend/danswer/agent_search/answer_query/graph_builder.py index e69de29bb2d..eafa871900d 100644 --- a/backend/danswer/agent_search/answer_query/graph_builder.py +++ b/backend/danswer/agent_search/answer_query/graph_builder.py @@ -0,0 +1,72 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from danswer.agent_search.answer_query.nodes.qa_check import qa_check +from danswer.agent_search.answer_query.nodes.qa_generation import qa_generation +from danswer.agent_search.answer_query.states import AnswerQueryInput +from 
danswer.agent_search.answer_query.states import AnswerQueryState +from danswer.agent_search.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) + + +def qa_graph_builder() -> StateGraph: + graph = StateGraph(AnswerQueryState) + + expanded_retrieval_graph = expanded_retrieval_graph_builder() + compiled_expanded_retrieval_graph = expanded_retrieval_graph.compile() + graph.add_node( + node="compiled_expanded_retrieval_graph", + action=compiled_expanded_retrieval_graph, + ) + graph.add_node( + node="qa_check", + action=qa_check, + ) + graph.add_node( + node="qa_generation", + action=qa_generation, + ) + + graph.add_edge( + start_key=START, + end_key="compiled_expanded_retrieval_graph", + ) + graph.add_edge( + start_key="compiled_expanded_retrieval_graph", + end_key="qa_generation", + ) + graph.add_edge( + start_key="qa_generation", + end_key="qa_check", + ) + graph.add_edge( + start_key="qa_check", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from danswer.db.engine import get_session_context_manager + from danswer.llm.factory import get_default_llms + from danswer.context.search.models import SearchRequest + + graph = qa_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="Who made Excel and what other products did they make?", + ) + with get_session_context_manager() as db_session: + inputs = AnswerQueryInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + query_to_answer="Who made Excel?", + ) + for thing in compiled_graph.stream(inputs, debug=True): + print(thing) diff --git a/backend/danswer/agent_search/answer_query/nodes/generate.py b/backend/danswer/agent_search/answer_query/nodes/generate.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/backend/danswer/agent_search/answer_query/nodes/qa_check.py b/backend/danswer/agent_search/answer_query/nodes/qa_check.py index e69de29bb2d..e6dc96eb1ab 100644 --- a/backend/danswer/agent_search/answer_query/nodes/qa_check.py +++ b/backend/danswer/agent_search/answer_query/nodes/qa_check.py @@ -0,0 +1,31 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from danswer.agent_search.answer_query.states import AnswerQueryState +from danswer.agent_search.answer_query.states import QACheckOutput +from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT + + +def qa_check(state: AnswerQueryState) -> QACheckOutput: + msg = [ + HumanMessage( + content=BASE_CHECK_PROMPT.format( + question=state["search_request"].query, + base_answer=state["answer"], + ) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + # structured_response_format=None, + ) + ) + + response_str = merge_message_runs(response, chunk_separator="")[0].content + + return QACheckOutput( + answer_quality=response_str, + ) diff --git a/backend/danswer/agent_search/answer_query/nodes/qa_generation.py b/backend/danswer/agent_search/answer_query/nodes/qa_generation.py new file mode 100644 index 00000000000..858c3be4c68 --- /dev/null +++ b/backend/danswer/agent_search/answer_query/nodes/qa_generation.py @@ -0,0 +1,32 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from danswer.agent_search.answer_query.states import AnswerQueryState +from danswer.agent_search.answer_query.states import 
QAGenerationOutput +from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_docs + + +def qa_generation(state: AnswerQueryState) -> QAGenerationOutput: + query = state["search_request"].query + docs = state["documents"] + + print(f"Number of verified retrieval docs: {len(docs)}") + + msg = [ + HumanMessage( + content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs)) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + answer_str = merge_message_runs(response, chunk_separator="")[0].content + return QAGenerationOutput( + answer=answer_str, + ) diff --git a/backend/danswer/agent_search/answer_query/states.py b/backend/danswer/agent_search/answer_query/states.py index e69de29bb2d..240f53698e3 100644 --- a/backend/danswer/agent_search/answer_query/states.py +++ b/backend/danswer/agent_search/answer_query/states.py @@ -0,0 +1,43 @@ +from typing import Annotated +from typing import TypedDict + +from danswer.agent_search.primary_state import PrimaryState +from danswer.context.search.models import InferenceSection +from danswer.llm.answering.prune_and_merge import _merge_sections + + +def dedup_inference_sections( + list1: list[InferenceSection], list2: list[InferenceSection] +) -> list[InferenceSection]: + deduped = _merge_sections(list1 + list2) + return deduped + + +class QACheckOutput(TypedDict, total=False): + answer_quality: bool + + +class QAGenerationOutput(TypedDict, total=False): + answer: str + + +class ExpandedRetrievalOutput(TypedDict): + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class AnswerQueryState( + PrimaryState, + QACheckOutput, + QAGenerationOutput, + ExpandedRetrievalOutput, + total=True, +): + query_to_answer: str + + +class AnswerQueryInput(PrimaryState, total=True): + query_to_answer: str + + +class AnswerQueryOutput(TypedDict): + documents: Annotated[list[InferenceSection], dedup_inference_sections] diff --git a/backend/danswer/agent_search/expanded_retrieval/edges.py b/backend/danswer/agent_search/expanded_retrieval/edges.py index 2524ec66ea0..5aa8357ce5b 100644 --- a/backend/danswer/agent_search/expanded_retrieval/edges.py +++ b/backend/danswer/agent_search/expanded_retrieval/edges.py @@ -24,7 +24,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> Literal["doc_retri print(f"parallel_retrieval_edge state: {state.keys()}") # messages = state["base_answer_messages"] - question = state["query_to_expand"] + question = state["query_to_answer"] llm: LLM = state["fast_llm"] msg = [ diff --git a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py index b17e736b6e0..ca40bbeb235 100644 --- a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py @@ -16,7 +16,7 @@ from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState -def qa_graph_builder() -> StateGraph: +def expanded_retrieval_graph_builder() -> StateGraph: graph = StateGraph(ExpandedRetrievalState) graph.add_node( @@ -65,7 +65,7 @@ def qa_graph_builder() -> StateGraph: from danswer.llm.factory import get_default_llms from danswer.context.search.models import SearchRequest - graph = qa_graph_builder() + graph = expanded_retrieval_graph_builder() compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = 
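On the `documents: Annotated[list[InferenceSection], dedup_inference_sections]` field above: in LangGraph, the second argument of `Annotated` is a reducer that combines the existing channel value with each node's update, which is how parallel retrieval branches merge into one deduplicated list instead of clobbering each other. A toy sketch with strings standing in for `InferenceSection` (all names below are illustrative):

```python
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


def dedup_ids(current: list[str], new: list[str]) -> list[str]:
    # order-preserving dedupe; stands in for the _merge_sections-based reducer
    return list(dict.fromkeys(current + new))


class DocState(TypedDict):
    documents: Annotated[list[str], dedup_ids]


def branch_a(state: DocState) -> dict:
    return {"documents": ["doc-1", "doc-2"]}


def branch_b(state: DocState) -> dict:
    return {"documents": ["doc-2", "doc-3"]}


g = StateGraph(DocState)
g.add_node("a", branch_a)
g.add_node("b", branch_b)
g.add_edge(START, "a")
g.add_edge(START, "b")
g.add_edge("a", END)
g.add_edge("b", END)

# both branches run in the same superstep; the reducer merges their
# "documents" updates, so "doc-2" appears only once in the result
print(g.compile().invoke({"documents": []})["documents"])
```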
SearchRequest( @@ -81,15 +81,3 @@ def qa_graph_builder() -> StateGraph: ) for thing in compiled_graph.stream(inputs, debug=True): print(thing) - # output = compiled_graph.invoke(inputs) - # print("\nOUTPUT:") - # print(output.keys()) - # for key, value in output.items(): - # if key in [ - # "sub_question_answer", - # "sub_question_str", - # "sub_qas", - # "initial_sub_qas", - # "sub_question_answer", - # ]: - # print(f"{key}: {value}") diff --git a/backend/danswer/agent_search/expanded_retrieval/states.py b/backend/danswer/agent_search/expanded_retrieval/states.py index a51985964b8..823ffaba41c 100644 --- a/backend/danswer/agent_search/expanded_retrieval/states.py +++ b/backend/danswer/agent_search/expanded_retrieval/states.py @@ -32,11 +32,11 @@ class ExpandedRetrievalState( DocRerankingOutput, total=True, ): - query_to_expand: str + query_to_answer: str class ExpandedRetrievalInput(PrimaryState, total=True): - query_to_expand: str + query_to_answer: str class ExpandedRetrievalOutput(TypedDict): diff --git a/backend/danswer/agent_search/util_sub_graphs/collect_docs.py b/backend/danswer/agent_search/util_sub_graphs/collect_docs.py deleted file mode 100644 index 7666411ac95..00000000000 --- a/backend/danswer/agent_search/util_sub_graphs/collect_docs.py +++ /dev/null @@ -1,44 +0,0 @@ -import operator -from datetime import datetime -from typing import Annotated - -from danswer.agent_search.primary_state import PrimaryState -from danswer.context.search.models import InferenceSection - - -class CollectDocsInput(PrimaryState): - sub_question_retrieval_docs: Annotated[list[InferenceSection], operator.add] - sub_question_str: str - graph_start_time: datetime - - -class CollectDocsOutput(PrimaryState): - deduped_retrieval_docs: list[InferenceSection] - - -def collect_docs(state: CollectDocsInput) -> CollectDocsOutput: - """ - Dedupe the retrieved docs. - """ - - sub_question_retrieval_docs = state["sub_question_retrieval_docs"] - - print(f"Number of docs from steps: {len(sub_question_retrieval_docs)}") - dedupe_docs: list[InferenceSection] = [] - for retrieval_doc in sub_question_retrieval_docs: - if not any( - retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id - for doc in dedupe_docs - ): - dedupe_docs.append(retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return state - # return CollectDocsOutput( - # deduped_retrieval_docs=dedupe_docs, - # ) - return { - "deduped_retrieval_docs": dedupe_docs, - "test_var": "test_var", - } diff --git a/backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py b/backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py deleted file mode 100644 index b03348d6c30..00000000000 --- a/backend/danswer/agent_search/util_sub_graphs/dedupe_retrieved_docs.py +++ /dev/null @@ -1,65 +0,0 @@ -from datetime import datetime -from typing import TypedDict - -from langgraph.graph import END -from langgraph.graph import START -from langgraph.graph import StateGraph - -from danswer.context.search.models import InferenceSection - - -class DedupeRetrievedDocsInput(TypedDict): - pre_dedupe_docs: list[InferenceSection] - - -class DedupeRetrievedDocsOutput(TypedDict): - deduped_docs: list[InferenceSection] - - -def dedupe_retrieved_docs(state: DedupeRetrievedDocsInput) -> DedupeRetrievedDocsOutput: - """ - Dedupe the retrieved docs. 
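The `query_to_expand` -> `query_to_answer` rename above matters beyond naming taste: when a parent graph fans this subgraph out with `Send()`, the payload keys must line up with the subgraph's input schema, so parent and subgraph need to agree on the key name. A hypothetical fan-out edge (the node name and the `queries_to_answer` key are illustrative, not from this patch):

```python
from langgraph.types import Send


def kick_off_expanded_retrieval(state: dict) -> list[Send]:
    # one Send per query; each payload must satisfy ExpandedRetrievalInput,
    # i.e. carry the renamed "query_to_answer" key
    return [
        Send("expanded_retrieval", {**state, "query_to_answer": query})
        for query in state["queries_to_answer"]
    ]
```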
- """ - datetime.now() - - pre_dedupe_docs = state["pre_dedupe_docs"] - - print(f"Number of docs from steps: {len(pre_dedupe_docs)}") - dedupe_docs: list[InferenceSection] = [] - for pre_dedupe_doc in pre_dedupe_docs: - if not any( - pre_dedupe_doc.center_chunk.document_id == doc.center_chunk.document_id - for doc in dedupe_docs - ): - dedupe_docs.append(pre_dedupe_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return DedupeRetrievedDocsOutput( - deduped_docs=dedupe_docs, - ) - - -def build_dedupe_retrieved_docs_graph() -> StateGraph: - dedupe_retrieved_docs_graph = StateGraph( - state_schema=DedupeRetrievedDocsInput, - input=DedupeRetrievedDocsInput, - output=DedupeRetrievedDocsOutput, - ) - - dedupe_retrieved_docs_graph.add_node( - node="dedupe_retrieved_docs", - action=dedupe_retrieved_docs, - ) - - dedupe_retrieved_docs_graph.add_edge( - start_key=START, - end_key="dedupe_retrieved_docs", - ) - - dedupe_retrieved_docs_graph.add_edge( - start_key="dedupe_retrieved_docs", - end_key=END, - ) - - return dedupe_retrieved_docs_graph From 24525ca3d846d13e3b1b23118f408b37907f125a Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 13 Dec 2024 12:18:16 -0800 Subject: [PATCH 09/10] all done --- .../answer_query/graph_builder.py | 70 ++++++--- .../nodes/{qa_check.py => answer_check.py} | 3 +- ...{qa_generation.py => answer_generation.py} | 4 +- .../answer_query/nodes/format_answer.py | 16 ++ .../agent_search/answer_query/states.py | 20 +-- .../agent_search/core_qa_graph/edges.py | 40 ----- .../core_qa_graph/graph_builder.py | 128 ---------------- .../nodes/combine_retrieved_docs.py | 35 ----- .../core_qa_graph/nodes/custom_retrieve.py | 52 ------- .../agent_search/core_qa_graph/nodes/dummy.py | 24 --- .../core_qa_graph/nodes/final_format.py | 22 --- .../core_qa_graph/nodes/generate.py | 55 ------- .../core_qa_graph/nodes/qa_check.py | 51 ------ .../core_qa_graph/nodes/rewrite.py | 62 -------- .../core_qa_graph/nodes/verifier.py | 64 -------- .../agent_search/core_qa_graph/states.py | 90 ----------- .../{primary_state.py => core_state.py} | 4 - .../__init__.py => deep_answer/edges.py} | 0 .../graph_builder.py} | 0 .../nodes/answer_generation.py} | 77 +++++++--- .../nodes/deep_decomp.py} | 4 +- .../nodes/entity_term_extraction.py | 19 +-- .../nodes/sub_qa_level_aggregator.py | 13 +- .../deep_answer/nodes/sub_qa_manager.py | 19 +++ .../__init__.py => deep_answer/states.py} | 0 .../agent_search/deep_qa_graph/edges.py | 41 ----- .../deep_qa_graph/graph_builder.py | 90 ----------- .../nodes/combine_retrieved_docs.py | 31 ---- .../deep_qa_graph/nodes/custom_retrieve.py | 33 ---- .../agent_search/deep_qa_graph/nodes/dummy.py | 21 --- .../deep_qa_graph/nodes/final_format.py | 31 ---- .../deep_qa_graph/nodes/generate.py | 56 ------- .../deep_qa_graph/nodes/qa_check.py | 57 ------- .../deep_qa_graph/nodes/rewrite.py | 64 -------- .../deep_qa_graph/nodes/verifier.py | 59 ------- .../agent_search/deep_qa_graph/prompts.py | 13 -- .../agent_search/deep_qa_graph/states.py | 64 -------- .../agent_search/expanded_retrieval/edges.py | 19 +-- .../expanded_retrieval/graph_builder.py | 26 +++- .../expanded_retrieval/nodes/doc_retrieval.py | 93 +++-------- ...rieved_docs.py => verification_kickoff.py} | 4 +- .../agent_search/expanded_retrieval/states.py | 11 +- backend/danswer/agent_search/main/edges.py | 61 ++++++++ .../agent_search/main/graph_builder.py | 99 ++++++++++++ .../agent_search/main/nodes/base_decomp.py | 31 ++++ .../main/nodes/generate_initial_answer.py | 51 ++++++ 
backend/danswer/agent_search/main/states.py | 37 +++++ .../agent_search/primary_graph/edges.py | 75 --------- .../primary_graph/graph_builder.py | 145 ------------------ .../primary_graph/nodes/base_wait.py | 27 ---- .../nodes/combine_retrieved_docs.py | 36 ----- .../primary_graph/nodes/custom_retrieve.py | 52 ------- .../nodes/deep_answer_generation.py | 61 -------- .../primary_graph/nodes/dummy_start.py | 11 -- .../primary_graph/nodes/generate.py | 52 ------- .../primary_graph/nodes/generate_initial.py | 72 --------- .../primary_graph/nodes/main_decomp_base.py | 64 -------- .../primary_graph/nodes/rewrite.py | 55 ------- .../primary_graph/nodes/sub_qa_manager.py | 28 ---- .../primary_graph/nodes/verifier.py | 59 ------- .../agent_search/primary_graph/prompts.py | 86 ----------- .../agent_search/primary_graph/states.py | 73 --------- .../agent_search/shared_graph_utils/models.py | 4 - .../shared_graph_utils/operators.py | 9 ++ .../shared_graph_utils/prompts.py | 88 +++++++++++ .../agent_search/shared_graph_utils/utils.py | 14 +- backend/danswer/agent_search/test.py | 120 --------------- backend/danswer/utils/timing.py | 12 +- 68 files changed, 598 insertions(+), 2409 deletions(-) rename backend/danswer/agent_search/answer_query/nodes/{qa_check.py => answer_check.py} (88%) rename backend/danswer/agent_search/answer_query/nodes/{qa_generation.py => answer_generation.py} (89%) create mode 100644 backend/danswer/agent_search/answer_query/nodes/format_answer.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/edges.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/graph_builder.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/dummy.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/final_format.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/generate.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/nodes/verifier.py delete mode 100644 backend/danswer/agent_search/core_qa_graph/states.py rename backend/danswer/agent_search/{primary_state.py => core_state.py} (87%) rename backend/danswer/agent_search/{core_qa_graph/nodes/__init__.py => deep_answer/edges.py} (100%) rename backend/danswer/agent_search/{deep_qa_graph/nodes/__init__.py => deep_answer/graph_builder.py} (100%) rename backend/danswer/agent_search/{primary_graph/nodes/final_stuff.py => deep_answer/nodes/answer_generation.py} (57%) rename backend/danswer/agent_search/{primary_graph/nodes/decompose.py => deep_answer/nodes/deep_decomp.py} (95%) rename backend/danswer/agent_search/{primary_graph => deep_answer}/nodes/entity_term_extraction.py (53%) rename backend/danswer/agent_search/{primary_graph => deep_answer}/nodes/sub_qa_level_aggregator.py (64%) create mode 100644 backend/danswer/agent_search/deep_answer/nodes/sub_qa_manager.py rename backend/danswer/agent_search/{primary_graph/nodes/__init__.py => deep_answer/states.py} (100%) delete mode 100644 backend/danswer/agent_search/deep_qa_graph/edges.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/graph_builder.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py delete mode 100644 
backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/generate.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/prompts.py delete mode 100644 backend/danswer/agent_search/deep_qa_graph/states.py rename backend/danswer/agent_search/expanded_retrieval/nodes/{collect_retrieved_docs.py => verification_kickoff.py} (88%) create mode 100644 backend/danswer/agent_search/main/edges.py create mode 100644 backend/danswer/agent_search/main/graph_builder.py create mode 100644 backend/danswer/agent_search/main/nodes/base_decomp.py create mode 100644 backend/danswer/agent_search/main/nodes/generate_initial_answer.py create mode 100644 backend/danswer/agent_search/main/states.py delete mode 100644 backend/danswer/agent_search/primary_graph/edges.py delete mode 100644 backend/danswer/agent_search/primary_graph/graph_builder.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/base_wait.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/dummy_start.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/generate.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/generate_initial.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/rewrite.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py delete mode 100644 backend/danswer/agent_search/primary_graph/nodes/verifier.py delete mode 100644 backend/danswer/agent_search/primary_graph/prompts.py delete mode 100644 backend/danswer/agent_search/primary_graph/states.py create mode 100644 backend/danswer/agent_search/shared_graph_utils/operators.py delete mode 100644 backend/danswer/agent_search/test.py diff --git a/backend/danswer/agent_search/answer_query/graph_builder.py b/backend/danswer/agent_search/answer_query/graph_builder.py index eafa871900d..bded80549b8 100644 --- a/backend/danswer/agent_search/answer_query/graph_builder.py +++ b/backend/danswer/agent_search/answer_query/graph_builder.py @@ -2,47 +2,64 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from danswer.agent_search.answer_query.nodes.qa_check import qa_check -from danswer.agent_search.answer_query.nodes.qa_generation import qa_generation +from danswer.agent_search.answer_query.nodes.answer_check import answer_check +from danswer.agent_search.answer_query.nodes.answer_generation import answer_generation +from danswer.agent_search.answer_query.nodes.format_answer import format_answer from danswer.agent_search.answer_query.states import AnswerQueryInput +from danswer.agent_search.answer_query.states import AnswerQueryOutput from danswer.agent_search.answer_query.states import AnswerQueryState from 
danswer.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) -def qa_graph_builder() -> StateGraph: - graph = StateGraph(AnswerQueryState) +def answer_query_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=AnswerQueryState, + input=AnswerQueryInput, + output=AnswerQueryOutput, + ) + + ### Add nodes ### - expanded_retrieval_graph = expanded_retrieval_graph_builder() - compiled_expanded_retrieval_graph = expanded_retrieval_graph.compile() + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="expanded_retrieval_for_initial_decomp", + action=expanded_retrieval, + ) graph.add_node( - node="compiled_expanded_retrieval_graph", - action=compiled_expanded_retrieval_graph, + node="answer_check", + action=answer_check, ) graph.add_node( - node="qa_check", - action=qa_check, + node="answer_generation", + action=answer_generation, ) graph.add_node( - node="qa_generation", - action=qa_generation, + node="format_answer", + action=format_answer, ) + ### Add edges ### + graph.add_edge( start_key=START, - end_key="compiled_expanded_retrieval_graph", + end_key="expanded_retrieval_for_initial_decomp", + ) + graph.add_edge( + start_key="expanded_retrieval_for_initial_decomp", + end_key="answer_generation", ) graph.add_edge( - start_key="compiled_expanded_retrieval_graph", - end_key="qa_generation", + start_key="answer_generation", + end_key="answer_check", ) graph.add_edge( - start_key="qa_generation", - end_key="qa_check", + start_key="answer_check", + end_key="format_answer", ) graph.add_edge( - start_key="qa_check", + start_key="format_answer", end_key=END, ) @@ -54,7 +71,7 @@ def qa_graph_builder() -> StateGraph: from danswer.llm.factory import get_default_llms from danswer.context.search.models import SearchRequest - graph = qa_graph_builder() + graph = answer_query_graph_builder() compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( @@ -68,5 +85,16 @@ def qa_graph_builder() -> StateGraph: db_session=db_session, query_to_answer="Who made Excel?", ) - for thing in compiled_graph.stream(inputs, debug=True): - print(thing) + output = compiled_graph.invoke( + input=inputs, + # debug=True, + # subgraphs=True, + ) + print(output) + # for namespace, chunk in compiled_graph.stream( + # input=inputs, + # # debug=True, + # subgraphs=True, + # ): + # print(namespace) + # print(chunk) diff --git a/backend/danswer/agent_search/answer_query/nodes/qa_check.py b/backend/danswer/agent_search/answer_query/nodes/answer_check.py similarity index 88% rename from backend/danswer/agent_search/answer_query/nodes/qa_check.py rename to backend/danswer/agent_search/answer_query/nodes/answer_check.py index e6dc96eb1ab..ba9b541f2a2 100644 --- a/backend/danswer/agent_search/answer_query/nodes/qa_check.py +++ b/backend/danswer/agent_search/answer_query/nodes/answer_check.py @@ -6,7 +6,7 @@ from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT -def qa_check(state: AnswerQueryState) -> QACheckOutput: +def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( content=BASE_CHECK_PROMPT.format( @@ -20,7 +20,6 @@ def qa_check(state: AnswerQueryState) -> QACheckOutput: response = list( fast_llm.stream( prompt=msg, - # structured_response_format=None, ) ) diff --git a/backend/danswer/agent_search/answer_query/nodes/qa_generation.py b/backend/danswer/agent_search/answer_query/nodes/answer_generation.py similarity index 89% rename from 
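The rebuilt `answer_query_graph_builder` above passes separate `input=`/`output=` schemas to `StateGraph`: the graph runs over the full internal state, but `invoke()` returns only the keys declared on the output schema. A minimal self-contained sketch (toy schemas, not the real AnswerQuery* types):

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class InternalState(TypedDict, total=False):
    question: str
    scratchpad: str  # internal-only key, filtered out of the result
    answer: str


class InputState(TypedDict):
    question: str


class OutputState(TypedDict):
    answer: str


def answer_node(state: InternalState) -> dict:
    return {"scratchpad": "working...", "answer": f"echo: {state['question']}"}


g = StateGraph(InternalState, input=InputState, output=OutputState)
g.add_node("answer", answer_node)
g.add_edge(START, "answer")
g.add_edge("answer", END)

print(g.compile().invoke({"question": "Who made Excel?"}))
# -> {'answer': 'echo: Who made Excel?'}  (scratchpad is not returned)
```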
backend/danswer/agent_search/answer_query/nodes/qa_generation.py rename to backend/danswer/agent_search/answer_query/nodes/answer_generation.py index 858c3be4c68..268558e08ff 100644 --- a/backend/danswer/agent_search/answer_query/nodes/qa_generation.py +++ b/backend/danswer/agent_search/answer_query/nodes/answer_generation.py @@ -7,8 +7,8 @@ from danswer.agent_search.shared_graph_utils.utils import format_docs -def qa_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["search_request"].query +def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: + query = state["query_to_answer"] docs = state["documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/danswer/agent_search/answer_query/nodes/format_answer.py b/backend/danswer/agent_search/answer_query/nodes/format_answer.py new file mode 100644 index 00000000000..a3a06ac7347 --- /dev/null +++ b/backend/danswer/agent_search/answer_query/nodes/format_answer.py @@ -0,0 +1,16 @@ +from danswer.agent_search.answer_query.states import AnswerQueryOutput +from danswer.agent_search.answer_query.states import AnswerQueryState +from danswer.agent_search.answer_query.states import SearchAnswerResults + + +def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: + return AnswerQueryOutput( + decomp_answer_results=[ + SearchAnswerResults( + query=state["query_to_answer"], + quality=state["answer_quality"], + answer=state["answer"], + documents=state["documents"], + ) + ], + ) diff --git a/backend/danswer/agent_search/answer_query/states.py b/backend/danswer/agent_search/answer_query/states.py index 240f53698e3..bb9ac37d40e 100644 --- a/backend/danswer/agent_search/answer_query/states.py +++ b/backend/danswer/agent_search/answer_query/states.py @@ -1,20 +1,22 @@ from typing import Annotated from typing import TypedDict -from danswer.agent_search.primary_state import PrimaryState +from pydantic import BaseModel + +from danswer.agent_search.core_state import PrimaryState +from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections from danswer.context.search.models import InferenceSection -from danswer.llm.answering.prune_and_merge import _merge_sections -def dedup_inference_sections( - list1: list[InferenceSection], list2: list[InferenceSection] -) -> list[InferenceSection]: - deduped = _merge_sections(list1 + list2) - return deduped +class SearchAnswerResults(BaseModel): + query: str + answer: str + quality: str + documents: Annotated[list[InferenceSection], dedup_inference_sections] class QACheckOutput(TypedDict, total=False): - answer_quality: bool + answer_quality: str class QAGenerationOutput(TypedDict, total=False): @@ -40,4 +42,4 @@ class AnswerQueryInput(PrimaryState, total=True): class AnswerQueryOutput(TypedDict): - documents: Annotated[list[InferenceSection], dedup_inference_sections] + decomp_answer_results: list[SearchAnswerResults] diff --git a/backend/danswer/agent_search/core_qa_graph/edges.py b/backend/danswer/agent_search/core_qa_graph/edges.py deleted file mode 100644 index 3a175399c7c..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/edges.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Literal - -from langgraph.types import Send - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.primary_graph.states import VerifierState - - -def continue_to_verifier(state: BaseQAState) -> Literal["verifier"]: - # 
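A plausible reading of why `format_answer` above wraps its result in a one-element list: when the main graph runs one `answer_query` subgraph per decomposed question in parallel, an `operator.add`-style reducer on `decomp_answer_results` concatenates the singleton lists into one result list. The `MainStateSlice` name below is hypothetical, and the model is a trimmed mirror (the real `SearchAnswerResults` also carries documents):

```python
import operator
from typing import Annotated, TypedDict

from pydantic import BaseModel


class SearchAnswerResults(BaseModel):
    query: str
    answer: str
    quality: str


class MainStateSlice(TypedDict):
    decomp_answer_results: Annotated[list[SearchAnswerResults], operator.add]


branch_1 = [SearchAnswerResults(query="Who made Excel?", answer="Microsoft", quality="yes")]
branch_2 = [SearchAnswerResults(query="What else did they make?", answer="Windows, Word", quality="yes")]

# what the reducer does when both branches report back in one superstep
combined = operator.add(branch_1, branch_2)
assert [r.query for r in combined] == ["Who made Excel?", "What else did they make?"]
```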
Routes each de-douped retrieved doc to the verifier step - in parallel - # Notice the 'Send()' API that takes care of the parallelization - - return [ - Send( - "verifier", - VerifierState( - document=doc, - question=state["sub_question_str"], - graph_start_time=state["graph_start_time"], - ), - ) - for doc in state["sub_question_retrieval_docs"] - ] - - -def continue_to_retrieval(state: BaseQAState) -> Literal["custom_retrieve"]: - # Routes re-written queries to the (parallel) retrieval steps - # Notice the 'Send()' API that takes care of the parallelization - rewritten_queries = state["sub_question_search_queries"].rewritten_queries - return [ - Send( - "custom_retrieve", - RetrieverState( - rewritten_query=query, - graph_start_time=state["graph_start_time"], - ), - ) - for query in rewritten_queries - ] diff --git a/backend/danswer/agent_search/core_qa_graph/graph_builder.py b/backend/danswer/agent_search/core_qa_graph/graph_builder.py deleted file mode 100644 index 985b8669f17..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/graph_builder.py +++ /dev/null @@ -1,128 +0,0 @@ -from langgraph.graph import END -from langgraph.graph import START -from langgraph.graph import StateGraph - -from danswer.agent_search.core_qa_graph.edges import continue_to_retrieval -from danswer.agent_search.core_qa_graph.edges import continue_to_verifier -from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import ( - custom_retrieve, -) -from danswer.agent_search.core_qa_graph.nodes.dummy import dummy -from danswer.agent_search.core_qa_graph.nodes.final_format import ( - final_format, -) -from danswer.agent_search.core_qa_graph.nodes.generate import generate -from danswer.agent_search.core_qa_graph.nodes.qa_check import qa_check -from danswer.agent_search.core_qa_graph.nodes.rewrite import rewrite -from danswer.agent_search.core_qa_graph.nodes.verifier import verifier -from danswer.agent_search.core_qa_graph.states import BaseQAOutputState -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.core_qa_graph.states import CoreQAInputState -from danswer.agent_search.util_sub_graphs.collect_docs import collect_docs -from danswer.agent_search.util_sub_graphs.dedupe_retrieved_docs import ( - build_dedupe_retrieved_docs_graph, -) - -# from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import combine_retrieved_docs - - -def build_core_qa_graph() -> StateGraph: - answers_initial = StateGraph( - state_schema=BaseQAState, - output=BaseQAOutputState, - ) - - ### Add nodes ### - answers_initial.add_node(node="dummy", action=dummy) - answers_initial.add_node(node="rewrite", action=rewrite) - answers_initial.add_node( - node="custom_retrieve", - action=custom_retrieve, - ) - # answers_initial.add_node( - # node="collect_docs", - # action=collect_docs, - # ) - build_dedupe_retrieved_docs_graph().compile() - answers_initial.add_node( - node="collect_docs", - action=collect_docs, - ) - answers_initial.add_node( - node="verifier", - action=verifier, - ) - answers_initial.add_node( - node="generate", - action=generate, - ) - answers_initial.add_node( - node="qa_check", - action=qa_check, - ) - answers_initial.add_node( - node="final_format", - action=final_format, - ) - - ### Add edges ### - answers_initial.add_edge(START, "dummy") - answers_initial.add_edge("dummy", "rewrite") - - answers_initial.add_conditional_edges( - source="rewrite", - path=continue_to_retrieval, - ) - - answers_initial.add_edge( - start_key="custom_retrieve", - 
end_key="collect_docs", - ) - - answers_initial.add_conditional_edges( - source="collect_docs", - path=continue_to_verifier, - ) - - answers_initial.add_edge( - start_key="verifier", - end_key="generate", - ) - - answers_initial.add_edge( - start_key="generate", - end_key="qa_check", - ) - - answers_initial.add_edge( - start_key="qa_check", - end_key="final_format", - ) - - answers_initial.add_edge( - start_key="final_format", - end_key=END, - ) - # answers_graph = answers_initial.compile() - return answers_initial - - -if __name__ == "__main__": - inputs = CoreQAInputState( - original_question="Whose music is kind of hard to easily enjoy?", - sub_question_str="Whose music is kind of hard to easily enjoy?", - ) - sub_answers_graph = build_core_qa_graph() - compiled_sub_answers = sub_answers_graph.compile() - output = compiled_sub_answers.invoke(inputs) - print("\nOUTPUT:") - print(output.keys()) - for key, value in output.items(): - if key in [ - "sub_question_answer", - "sub_question_str", - "sub_qas", - "initial_sub_qas", - "sub_question_answer", - ]: - print(f"{key}: {value}") diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py deleted file mode 100644 index bb45f083328..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py +++ /dev/null @@ -1,35 +0,0 @@ -from datetime import datetime -from typing import Any - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.context.search.models import InferenceSection - - -def combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]: - """ - Dedupe the retrieved docs. 
- """ - node_start_time = datetime.now() - - sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"] - - print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}") - dedupe_docs: list[InferenceSection] = [] - for base_retrieval_doc in sub_question_base_retrieval_docs: - if not any( - base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id - for doc in dedupe_docs - ): - dedupe_docs.append(base_retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return { - "sub_question_deduped_retrieval_docs": dedupe_docs, - "log_messages": generate_log_message( - message="sub - combine_retrieved_docs (dedupe)", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py deleted file mode 100644 index 1fe4a840c99..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py +++ /dev/null @@ -1,52 +0,0 @@ -import datetime - -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.util_sub_graphs.collect_docs import CollectDocsInput -from danswer.agent_search.util_sub_graphs.dedupe_retrieved_docs import ( - DedupeRetrievedDocsInput, -) -from danswer.context.search.models import InferenceSection -from danswer.context.search.models import SearchRequest -from danswer.context.search.pipeline import SearchPipeline -from danswer.db.engine import get_session_context_manager -from danswer.llm.factory import get_default_llms - - -def custom_retrieve(state: RetrieverState) -> DedupeRetrievedDocsInput: - """ - Retrieve documents - - Args: - state (dict): The current graph state - - Returns: - state (dict): New key added to state, documents, that contains retrieved documents - """ - print("---RETRIEVE SUB---") - - datetime.datetime.now() - - rewritten_query = state["rewritten_query"] - - # Retrieval - # TODO: add the actual retrieval, probably from search_tool.run() - documents: list[InferenceSection] = [] - llm, fast_llm = get_default_llms() - with get_session_context_manager() as db_session: - documents = SearchPipeline( - search_request=SearchRequest( - query=rewritten_query, - ), - user=None, - llm=llm, - fast_llm=fast_llm, - db_session=db_session, - ).reranked_sections - - return CollectDocsInput( - sub_question_retrieval_docs=documents, - ) - - # return DedupeRetrievedDocsInput( - # pre_dedupe_docs=documents, - # ) diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py b/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py deleted file mode 100644 index e8718838300..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/dummy.py +++ /dev/null @@ -1,24 +0,0 @@ -import datetime -from typing import Any - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message - - -def dummy(state: BaseQAState) -> dict[str, Any]: - """ - Dummy step - """ - - print("---Sub Dummy---") - - node_start_time = datetime.datetime.now() - - return { - "graph_start_time": node_start_time, - "log_messages": generate_log_message( - message="sub - dummy", - node_start_time=node_start_time, - graph_start_time=node_start_time, - ), - } diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py b/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py deleted file mode 100644 
index c75262d08ad..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/final_format.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any - -from danswer.agent_search.core_qa_graph.states import BaseQAState - - -def final_format(state: BaseQAState) -> dict[str, Any]: - """ - Create the final output for the QA subgraph - """ - - print("---BASE FINAL FORMAT---") - - return { - "sub_qas": [ - { - "sub_question": state["sub_question_str"], - "sub_answer": state["sub_question_answer"], - "sub_answer_check": state["sub_question_answer_check"], - } - ], - "log_messages": state["log_messages"], - } diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/generate.py b/backend/danswer/agent_search/core_qa_graph/nodes/generate.py deleted file mode 100644 index 21416d7473e..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/generate.py +++ /dev/null @@ -1,55 +0,0 @@ -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.llm.factory import get_default_llms - - -def generate(state: BaseQAState) -> dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---GENERATE---") - node_start_time = datetime.now() - - question = state["original_question"] - docs = state["sub_question_verified_retrieval_docs"] - - print(f"Number of verified retrieval docs: {len(docs)}") - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) - ) - ] - - # Grader - _, fast_llm = get_default_llms() - response = list( - fast_llm.stream( - prompt=msg, - # structured_response_format=None, - ) - ) - - answer_str = merge_message_runs(response, chunk_separator="")[0].content - return { - "sub_question_answer": answer_str, - "log_messages": generate_log_message( - message="base - generate", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py b/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py deleted file mode 100644 index 3226bac00c0..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py +++ /dev/null @@ -1,51 +0,0 @@ -import datetime -from typing import Any - -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.llm.factory import get_default_llms - - -def qa_check(state: BaseQAState) -> dict[str, Any]: - """ - Check if the sub-question answer is satisfactory. 
- - Args: - state: The current SubQAState containing the sub-question and its answer - - Returns: - dict containing the check result and log message - """ - node_start_time = datetime.datetime.now() - - msg = [ - HumanMessage( - content=BASE_CHECK_PROMPT.format( - question=state["sub_question_str"], - base_answer=state["sub_question_answer"], - ) - ) - ] - - _, fast_llm = get_default_llms() - response = list( - fast_llm.stream( - prompt=msg, - # structured_response_format=None, - ) - ) - - response_str = merge_message_runs(response, chunk_separator="")[0].content - - return { - "sub_question_answer_check": response_str, - "base_answer_messages": generate_log_message( - message="sub - qa_check", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py b/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py deleted file mode 100644 index 57118227b9e..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py +++ /dev/null @@ -1,62 +0,0 @@ -import datetime -from typing import Any - -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.llm.factory import get_default_llms - - -def rewrite(state: BaseQAState) -> dict[str, Any]: - """ - Transform the initial question into more suitable search queries. - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - print("---SUB TRANSFORM QUERY---") - - node_start_time = datetime.datetime.now() - - # messages = state["base_answer_messages"] - question = state["sub_question_str"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), - ) - ] - _, fast_llm = get_default_llms() - llm_response_list = list( - fast_llm.stream( - prompt=msg, - # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()}, - # structured_response_format=RewrittenQueries.model_json_schema(), - ) - ) - llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - - print(f"llm_response: {llm_response}") - - rewritten_queries = llm_response.split("\n") - - print(f"rewritten_queries: {rewritten_queries}") - - rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries) - - return { - "sub_question_search_queries": rewritten_queries, - "log_messages": generate_log_message( - message="sub - rewrite", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py b/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py deleted file mode 100644 index e5b89a2761d..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/nodes/verifier.py +++ /dev/null @@ -1,64 +0,0 @@ -import datetime -from typing import Any - -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.shared_graph_utils.models import BinaryDecision -from danswer.agent_search.shared_graph_utils.prompts 
import VERIFIER_PROMPT -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.llm.factory import get_default_llms - - -def verifier(state: VerifierState) -> dict[str, Any]: - """ - Check whether the document is relevant for the original user question - - Args: - state (VerifierState): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---VERIFY QUTPUT---") - node_start_time = datetime.datetime.now() - - question = state["question"] - document_content = state["document"].combined_content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - # Grader - llm, fast_llm = get_default_llms() - response = list( - llm.stream( - prompt=msg, - # structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - response_string = merge_message_runs(response, chunk_separator="")[0].content - # Convert string response to proper dictionary format - decision_dict = {"decision": response_string.lower()} - formatted_response = BinaryDecision.model_validate(decision_dict) - - print(f"Verdict: {formatted_response.decision}") - - return { - "sub_question_verified_retrieval_docs": [state["document"]] - if formatted_response.decision == "yes" - else [], - "log_messages": generate_log_message( - message=f"sub - verifier: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/core_qa_graph/states.py b/backend/danswer/agent_search/core_qa_graph/states.py deleted file mode 100644 index 45cec7e58cf..00000000000 --- a/backend/danswer/agent_search/core_qa_graph/states.py +++ /dev/null @@ -1,90 +0,0 @@ -import operator -from collections.abc import Sequence -from datetime import datetime -from typing import Annotated -from typing import TypedDict - -from langchain_core.messages import BaseMessage -from langgraph.graph.message import add_messages - -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.context.search.models import InferenceSection -from danswer.llm.interfaces import LLM - - -class SubQuestionRetrieverState(TypedDict): - # The state for the parallel Retrievers. They each need to see only one query - sub_question_rewritten_query: str - - -class SubQuestionVerifierState(TypedDict): - # The state for the parallel verification step. Each node execution need to see only one question/doc pair - sub_question_document: InferenceSection - sub_question: str - - -class CoreQAInputState(TypedDict): - sub_question_str: str - original_question: str - - -class BaseQAState(TypedDict): - # The 'core SubQuestion' state. 
- original_question: str - graph_start_time: datetime - # start time for parallel initial sub-questionn thread - sub_query_start_time: datetime - sub_question_rewritten_queries: list[str] - sub_question_str: str - sub_question_search_queries: RewrittenQueries - deduped_retrieval_docs: list[InferenceSection] - sub_question_nr: int - sub_question_base_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_deduped_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_verified_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_reranked_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_top_chunks: Annotated[Sequence[dict], operator.add] - sub_question_answer: str - sub_question_answer_check: str - log_messages: Annotated[Sequence[BaseMessage], add_messages] - sub_qas: Annotated[Sequence[dict], operator.add] - # Answers sent back to core - initial_sub_qas: Annotated[Sequence[dict], operator.add] - primary_llm: LLM - fast_llm: LLM - - -class BaseQAOutputState(TypedDict): - # The 'SubQuestion' output state. Removes all the intermediate states - sub_question_rewritten_queries: list[str] - sub_question_str: str - sub_question_search_queries: list[str] - sub_question_nr: int - # Answers sent back to core - sub_qas: Annotated[Sequence[dict], operator.add] - # Answers sent back to core - initial_sub_qas: Annotated[Sequence[dict], operator.add] - sub_question_base_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_deduped_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_verified_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_reranked_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_top_chunks: Annotated[Sequence[dict], operator.add] - sub_question_answer: str - sub_question_answer_check: str - log_messages: Annotated[Sequence[BaseMessage], add_messages] diff --git a/backend/danswer/agent_search/primary_state.py b/backend/danswer/agent_search/core_state.py similarity index 87% rename from backend/danswer/agent_search/primary_state.py rename to backend/danswer/agent_search/core_state.py index 23898e774cb..cc0057136eb 100644 --- a/backend/danswer/agent_search/primary_state.py +++ b/backend/danswer/agent_search/core_state.py @@ -13,7 +13,3 @@ class PrimaryState(TypedDict, total=False): # a single session for the entire agent search # is fine if we are only reading db_session: Session - - -class InputState(PrimaryState, total=True): - pass diff --git a/backend/danswer/agent_search/core_qa_graph/nodes/__init__.py b/backend/danswer/agent_search/deep_answer/edges.py similarity index 100% rename from backend/danswer/agent_search/core_qa_graph/nodes/__init__.py rename to backend/danswer/agent_search/deep_answer/edges.py diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/__init__.py b/backend/danswer/agent_search/deep_answer/graph_builder.py similarity index 100% rename from backend/danswer/agent_search/deep_qa_graph/nodes/__init__.py rename to backend/danswer/agent_search/deep_answer/graph_builder.py diff --git a/backend/danswer/agent_search/primary_graph/nodes/final_stuff.py b/backend/danswer/agent_search/deep_answer/nodes/answer_generation.py similarity index 57% rename from backend/danswer/agent_search/primary_graph/nodes/final_stuff.py rename to 
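A subtlety in the primary_state.py -> core_state.py move above: totality in `TypedDict` is per-class, so the keys declared on `PrimaryState(total=False)` stay optional even in subclasses declared with `total=True`; only the keys a subclass adds itself become required. Sketch with stub value types standing in for the real `SearchRequest`/`LLM`/`Session`:

```python
from typing import Any, TypedDict


class PrimaryState(TypedDict, total=False):
    search_request: Any  # SearchRequest in the real code
    primary_llm: Any     # LLM
    fast_llm: Any        # LLM
    db_session: Any      # sqlalchemy Session


class AnswerQueryInput(PrimaryState, total=True):
    query_to_answer: str  # required; the inherited keys remain optional


# type-checks: query_to_answer must be present, the base keys need not be
minimal_input: AnswerQueryInput = {"query_to_answer": "Who made Excel?"}
```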
backend/danswer/agent_search/deep_answer/nodes/answer_generation.py index be115de8dda..99389eb4883 100644 --- a/backend/danswer/agent_search/primary_graph/nodes/final_stuff.py +++ b/backend/danswer/agent_search/deep_answer/nodes/answer_generation.py @@ -1,11 +1,58 @@ -from datetime import datetime from typing import Any -from danswer.agent_search.primary_graph.states import QAState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from langchain_core.messages import HumanMessage +from danswer.agent_search.main.states import MainState +from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT +from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT +from danswer.agent_search.shared_graph_utils.utils import format_docs +from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace -def final_stuff(state: QAState) -> dict[str, Any]: + +# aggregate sub questions and answers +def deep_answer_generation(state: MainState) -> dict[str, Any]: + """ + Generate answer + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + print("---DEEP GENERATE---") + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + deep_answer_context = state["core_answer_dynamic_context"] + + print(f"Number of verified retrieval docs - deep: {len(docs)}") + + combined_context = normalize_whitespace( + COMBINED_CONTEXT.format( + deep_answer_context=deep_answer_context, formated_docs=format_docs(docs) + ) + ) + + msg = [ + HumanMessage( + content=MODIFIED_RAG_PROMPT.format( + question=question, combined_context=combined_context + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + return { + "deep_answer": response.content, + } + + +def final_stuff(state: MainState) -> dict[str, Any]: """ Invokes the agent model to generate a response based on the current state. Given the question, it will decide to retrieve using the retriever tool, or simply end. 
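`deep_answer_generation` above runs the combined context through `normalize_whitespace` before prompting. A plausible implementation, assuming the helper simply collapses whitespace runs (the real one lives in `shared_graph_utils.utils` and may differ):

```python
import re


def normalize_whitespace(text: str) -> str:
    # collapse any run of spaces/newlines/tabs into a single space
    return re.sub(r"\s+", " ", text).strip()


assert normalize_whitespace("context:\n\n  doc 1\t doc 2 ") == "context: doc 1 doc 2"
```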
@@ -17,7 +64,6 @@ def final_stuff(state: QAState) -> dict[str, Any]: dict: The updated state with the agent response appended to messages """ print("---FINAL---") - node_start_time = datetime.now() messages = state["log_messages"] time_ordered_messages = [x.pretty_repr() for x in messages] @@ -36,15 +82,6 @@ def final_stuff(state: QAState) -> dict[str, Any]: initial_sub_qa_context = "\n".join(initial_sub_qa_list) - log_message = generate_log_message( - message="all - final_stuff", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ) - - print(log_message) - print("--------------------------------") - base_answer = state["base_answer"] print(f"Final Base Answer:\n{base_answer}") @@ -54,9 +91,7 @@ def final_stuff(state: QAState) -> dict[str, Any]: if not state.get("deep_answer"): print("No Deep Answer was required") - return { - "log_messages": log_message, - } + return {} deep_answer = state["deep_answer"] sub_qas = state["sub_qas"] @@ -76,10 +111,4 @@ def final_stuff(state: QAState) -> dict[str, Any]: print("Sub Questions and Answers:") print(sub_qa_context) - return { - "log_messages": generate_log_message( - message="all - final_stuff", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } + return {} diff --git a/backend/danswer/agent_search/primary_graph/nodes/decompose.py b/backend/danswer/agent_search/deep_answer/nodes/deep_decomp.py similarity index 95% rename from backend/danswer/agent_search/primary_graph/nodes/decompose.py rename to backend/danswer/agent_search/deep_answer/nodes/deep_decomp.py index 351d374b464..d61357640a0 100644 --- a/backend/danswer/agent_search/primary_graph/nodes/decompose.py +++ b/backend/danswer/agent_search/deep_answer/nodes/deep_decomp.py @@ -5,13 +5,13 @@ from langchain_core.messages import HumanMessage -from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.main.states import MainState from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction from danswer.agent_search.shared_graph_utils.utils import generate_log_message -def decompose(state: QAState) -> dict[str, Any]: +def decompose(state: MainState) -> dict[str, Any]: """ """ node_start_time = datetime.now() diff --git a/backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py b/backend/danswer/agent_search/deep_answer/nodes/entity_term_extraction.py similarity index 53% rename from backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py rename to backend/danswer/agent_search/deep_answer/nodes/entity_term_extraction.py index b19de6d4f39..e369707ee5c 100644 --- a/backend/danswer/agent_search/primary_graph/nodes/entity_term_extraction.py +++ b/backend/danswer/agent_search/deep_answer/nodes/entity_term_extraction.py @@ -1,21 +1,17 @@ import json import re -from datetime import datetime from typing import Any from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT -from danswer.agent_search.primary_graph.states import QAState +from danswer.agent_search.main.states import MainState +from danswer.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.llm.factory 
import get_default_llms -def entity_term_extraction(state: QAState) -> dict[str, Any]: +def entity_term_extraction(state: MainState) -> dict[str, Any]: """Extract entities and terms from the question and context""" - node_start_time = datetime.now() question = state["original_question"] docs = state["deduped_retrieval_docs"] @@ -27,13 +23,11 @@ def entity_term_extraction(state: QAState) -> dict[str, Any]: content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), ) ] - _, fast_llm = get_default_llms() + fast_llm = state["fast_llm"] # Grader llm_response_list = list( fast_llm.stream( prompt=msg, - # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()}, - # structured_response_format=RewrittenQueries.model_json_schema(), ) ) llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content @@ -43,9 +37,4 @@ def entity_term_extraction(state: QAState) -> dict[str, Any]: return { "retrieved_entities_relationships": parsed_response, - "log_messages": generate_log_message( - message="deep - entity term extraction", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), } diff --git a/backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py b/backend/danswer/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py similarity index 64% rename from backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py rename to backend/danswer/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py index 8d53ccc239b..d384dc51380 100644 --- a/backend/danswer/agent_search/primary_graph/nodes/sub_qa_level_aggregator.py +++ b/backend/danswer/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py @@ -1,16 +1,12 @@ -from datetime import datetime from typing import Any -from danswer.agent_search.primary_graph.states import QAState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message +from danswer.agent_search.main.states import MainState # aggregate sub questions and answers -def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]: +def sub_qa_level_aggregator(state: MainState) -> dict[str, Any]: sub_qas = state["sub_qas"] - node_start_time = datetime.now() - dynamic_context_list = [ "Below you will find useful information to answer the original question:" ] @@ -31,9 +27,4 @@ def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]: return { "core_answer_dynamic_context": dynamic_context, "checked_sub_qas": checked_sub_qas, - "log_messages": generate_log_message( - message="deep - sub qa level aggregator", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), } diff --git a/backend/danswer/agent_search/deep_answer/nodes/sub_qa_manager.py b/backend/danswer/agent_search/deep_answer/nodes/sub_qa_manager.py new file mode 100644 index 00000000000..11167cd04bc --- /dev/null +++ b/backend/danswer/agent_search/deep_answer/nodes/sub_qa_manager.py @@ -0,0 +1,19 @@ +from typing import Any + +from danswer.agent_search.main.states import MainState + + +def sub_qa_manager(state: MainState) -> dict[str, Any]: + """ """ + + sub_questions_dict = state["decomposed_sub_questions_dict"] + + sub_questions = {} + + for sub_question_nr, sub_question_dict in sub_questions_dict.items(): + sub_questions[sub_question_nr] = sub_question_dict["sub_question"] + + return { + "sub_questions": sub_questions, + "num_new_question_iterations": 0, + } diff --git a/backend/danswer/agent_search/primary_graph/nodes/__init__.py 
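For readers tracing the new `sub_qa_manager` node above: it only reshapes state, pulling each sub-question string out of the decomposition output. A hypothetical `decomposed_sub_questions_dict` to make the shape concrete (the real dict is produced by the decomposition node and may carry more fields):

```python
decomposed_sub_questions_dict = {
    1: {"sub_question": "Who made Excel?"},
    2: {"sub_question": "What other products did they make?"},
}

sub_questions = {
    nr: entry["sub_question"]
    for nr, entry in decomposed_sub_questions_dict.items()
}

assert sub_questions[1] == "Who made Excel?"
```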
b/backend/danswer/agent_search/deep_answer/states.py similarity index 100% rename from backend/danswer/agent_search/primary_graph/nodes/__init__.py rename to backend/danswer/agent_search/deep_answer/states.py diff --git a/backend/danswer/agent_search/deep_qa_graph/edges.py b/backend/danswer/agent_search/deep_qa_graph/edges.py deleted file mode 100644 index 6f0f143aa03..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/edges.py +++ /dev/null @@ -1,41 +0,0 @@ -from collections.abc import Hashable - -from langgraph.types import Send - -from danswer.agent_search.deep_qa_graph.states import ResearchQAState -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.primary_graph.states import VerifierState - - -def continue_to_verifier(state: ResearchQAState) -> list[Hashable]: - # Routes each de-douped retrieved doc to the verifier step - in parallel - # Notice the 'Send()' API that takes care of the parallelization - - return [ - Send( - "sub_verifier", - VerifierState( - document=doc, - question=state["sub_question"], - graph_start_time=state["graph_start_time"], - ), - ) - for doc in state["sub_question_base_retrieval_docs"] - ] - - -def continue_to_retrieval( - state: ResearchQAState, -) -> list[Hashable]: - # Routes re-written queries to the (parallel) retrieval steps - # Notice the 'Send()' API that takes care of the parallelization - return [ - Send( - "sub_custom_retrieve", - RetrieverState( - rewritten_query=query, - graph_start_time=state["graph_start_time"], - ), - ) - for query in state["sub_question_rewritten_queries"] - ] diff --git a/backend/danswer/agent_search/deep_qa_graph/graph_builder.py b/backend/danswer/agent_search/deep_qa_graph/graph_builder.py deleted file mode 100644 index 0e0cae88127..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/graph_builder.py +++ /dev/null @@ -1,90 +0,0 @@ -from langgraph.graph import END -from langgraph.graph import START -from langgraph.graph import StateGraph - -from danswer.agent_search.core_qa_graph.nodes.rewrite import rewrite -from danswer.agent_search.deep_qa_graph.edges import continue_to_retrieval -from danswer.agent_search.deep_qa_graph.edges import continue_to_verifier -from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import ( - combine_retrieved_docs, -) -from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import custom_retrieve -from danswer.agent_search.deep_qa_graph.nodes.final_format import final_format -from danswer.agent_search.deep_qa_graph.nodes.generate import generate -from danswer.agent_search.deep_qa_graph.nodes.qa_check import qa_check -from danswer.agent_search.deep_qa_graph.nodes.verifier import verifier -from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState -from danswer.agent_search.deep_qa_graph.states import ResearchQAState - - -def build_deep_qa_graph() -> StateGraph: - # Define the nodes we will cycle between - answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState) - - ### Add Nodes ### - - # Dummy node for initial processing - # answers.add_node(node="dummy", action=dummy) - answers.add_node(node="rewrite", action=rewrite) - - # The retrieval step - answers.add_node(node="custom_retrieve", action=custom_retrieve) - - # The dedupe step - answers.add_node(node="combine_retrieved_docs", action=combine_retrieved_docs) - - # Verifying retrieved information - answers.add_node(node="verifier", action=verifier) - - # Generating the response - answers.add_node(node="generate", 
-    answers.add_node(node="generate", action=generate)
-
-    # Checking the quality of the answer
-    answers.add_node(node="qa_check", action=qa_check)
-
-    # Final formatting of the response
-    answers.add_node(node="final_format", action=final_format)
-
-    ### Add Edges ###
-
-    # Generate multiple sub-questions
-    answers.add_edge(start_key=START, end_key="rewrite")
-
-    # For each sub-question, perform a retrieval in parallel
-    answers.add_conditional_edges(
-        source="rewrite",
-        path=continue_to_retrieval,
-        path_map=["custom_retrieve"],
-    )
-
-    # Combine the retrieved docs for each sub-question from the parallel retrievals
-    answers.add_edge(start_key="custom_retrieve", end_key="combine_retrieved_docs")
-
-    # Go over all of the combined retrieved docs and verify them against the original question
-    answers.add_conditional_edges(
-        source="combine_retrieved_docs",
-        path=continue_to_verifier,
-        path_map=["verifier"],
-    )
-
-    # Generate an answer for each verified retrieved doc
-    answers.add_edge(start_key="verifier", end_key="generate")
-
-    # Check the quality of the answer
-    answers.add_edge(start_key="generate", end_key="qa_check")
-
-    answers.add_edge(start_key="qa_check", end_key="final_format")
-
-    answers.add_edge(start_key="final_format", end_key=END)
-
-    return answers
-
-
-if __name__ == "__main__":
-    # TODO: add the actual question
-    inputs = {"question": "Whose music is kind of hard to easily enjoy?"}
-    answers_graph = build_deep_qa_graph()
-    compiled_answers = answers_graph.compile()
-    output = compiled_answers.invoke(inputs)
-    print("\nOUTPUT:")
-    print(output)
diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py
deleted file mode 100644
index f2a5be3754a..00000000000
--- a/backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from danswer.agent_search.deep_qa_graph.states import ResearchQAState
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
-    """
-    Dedupe the retrieved docs.
- """ - node_start_time = datetime.now() - - sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"] - - print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}") - dedupe_docs = [] - for base_retrieval_doc in sub_question_base_retrieval_docs: - if base_retrieval_doc not in dedupe_docs: - dedupe_docs.append(base_retrieval_doc) - - print(f"Number of deduped docs: {len(dedupe_docs)}") - - return { - "sub_question_deduped_retrieval_docs": dedupe_docs, - "log_messages": generate_log_message( - message="sub - combine_retrieved_docs (dedupe)", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py deleted file mode 100644 index 3e00a94a23d..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import Any - -from danswer.agent_search.primary_graph.states import RetrieverState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.context.search.models import InferenceSection - - -def custom_retrieve(state: RetrieverState) -> dict[str, Any]: - """ - Retrieve documents - - Args: - state (dict): The current graph state - - Returns: - state (dict): New key added to state, documents, that contains retrieved documents - """ - print("---RETRIEVE SUB---") - node_start_time = datetime.now() - - # Retrieval - # TODO: add the actual retrieval, probably from search_tool.run() - documents: list[InferenceSection] = [] - - return { - "sub_question_base_retrieval_docs": documents, - "log_messages": generate_log_message( - message="sub - custom_retrieve", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py b/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py deleted file mode 100644 index e5ada974d23..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py +++ /dev/null @@ -1,21 +0,0 @@ -from datetime import datetime -from typing import Any - -from danswer.agent_search.core_qa_graph.states import BaseQAState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message - - -def dummy(state: BaseQAState) -> dict[str, Any]: - """ - Dummy step - """ - - print("---Sub Dummy---") - - return { - "log_messages": generate_log_message( - message="sub - dummy", - node_start_time=datetime.now(), - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py b/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py deleted file mode 100644 index 6515c2b4e9e..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py +++ /dev/null @@ -1,31 +0,0 @@ -from datetime import datetime -from typing import Any - -from danswer.agent_search.deep_qa_graph.states import ResearchQAState -from danswer.agent_search.shared_graph_utils.utils import generate_log_message - - -def final_format(state: ResearchQAState) -> dict[str, Any]: - """ - Create the final output for the QA subgraph - """ - - print("---SUB FINAL FORMAT---") - node_start_time = datetime.now() - - return { - # TODO: Type this - "sub_qas": [ - { - "sub_question": state["sub_question"], - "sub_answer": state["sub_question_answer"], - "sub_question_nr": 
state["sub_question_nr"], - "sub_answer_check": state["sub_question_answer_check"], - } - ], - "log_messages": generate_log_message( - message="sub - final format", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py b/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py deleted file mode 100644 index 48faed35adc..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/generate.py +++ /dev/null @@ -1,56 +0,0 @@ -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from danswer.agent_search.deep_qa_graph.states import ResearchQAState -from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from danswer.agent_search.shared_graph_utils.utils import format_docs -from danswer.agent_search.shared_graph_utils.utils import generate_log_message - - -def generate(state: ResearchQAState) -> dict[str, Any]: - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - print("---SUB GENERATE---") - node_start_time = datetime.now() - - question = state["sub_question"] - docs = state["sub_question_verified_retrieval_docs"] - - print(f"Number of verified retrieval docs for sub-question: {len(docs)}") - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) - ) - ] - - # Grader - if len(docs) > 0: - model = state["fast_llm"] - response = list( - model.stream( - prompt=msg, - ) - ) - response_str = merge_message_runs(response, chunk_separator="")[0].content - else: - response_str = "" - - return { - "sub_question_answer": response_str, - "log_messages": generate_log_message( - message="sub - generate", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py b/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py deleted file mode 100644 index 40fe7c70427..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py +++ /dev/null @@ -1,57 +0,0 @@ -import json -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage - -from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT -from danswer.agent_search.deep_qa_graph.states import ResearchQAState -from danswer.agent_search.shared_graph_utils.models import BinaryDecision -from danswer.agent_search.shared_graph_utils.utils import generate_log_message - - -def qa_check(state: ResearchQAState) -> dict[str, Any]: - """ - Check whether the final output satisfies the original user question - - Args: - state (messages): The current state - - Returns: - dict: The updated state with the final decision - """ - - print("---CHECK SUB QUTPUT---") - node_start_time = datetime.now() - - sub_answer = state["sub_question_answer"] - sub_question = state["sub_question"] - - msg = [ - HumanMessage( - content=SUB_CHECK_PROMPT.format( - sub_question=sub_question, sub_answer=sub_answer - ) - ) - ] - - # Grader - model = state["fast_llm"] - response = list( - model.stream( - prompt=msg, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - raw_response = json.loads(response[0].pretty_repr()) - formatted_response = BinaryDecision.model_validate(raw_response) - - return { - 
"sub_question_answer_check": formatted_response.decision, - "log_messages": generate_log_message( - message=f"sub - qa check: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py b/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py deleted file mode 100644 index f0ac9fbdabd..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py +++ /dev/null @@ -1,64 +0,0 @@ -import json -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage - -from danswer.agent_search.deep_qa_graph.states import ResearchQAState -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI -from danswer.agent_search.shared_graph_utils.utils import generate_log_message -from danswer.llm.interfaces import LLM - - -def rewrite(state: ResearchQAState) -> dict[str, Any]: - """ - Transform the initial question into more suitable search queries. - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - print("---SUB TRANSFORM QUERY---") - node_start_time = datetime.now() - - question = state["sub_question"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), - ) - ] - fast_llm: LLM = state["fast_llm"] - llm_response = list( - fast_llm.stream( - prompt=msg, - structured_response_format=RewrittenQueries.model_json_schema(), - ) - ) - - # Get the rewritten queries in a defined format - rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr()) - - print(f"rewritten_queries: {rewritten_queries}") - - rewritten_queries = RewrittenQueries( - rewritten_queries=[ - "music hard to listen to", - "Music that is not fun or pleasant", - ] - ) - - print(f"hardcoded rewritten_queries: {rewritten_queries}") - - return { - "sub_question_rewritten_queries": rewritten_queries, - "log_messages": generate_log_message( - message="sub - rewrite", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py b/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py deleted file mode 100644 index c24f5a36c03..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py +++ /dev/null @@ -1,59 +0,0 @@ -import json -from datetime import datetime -from typing import Any - -from langchain_core.messages import HumanMessage - -from danswer.agent_search.primary_graph.states import VerifierState -from danswer.agent_search.shared_graph_utils.models import BinaryDecision -from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from danswer.agent_search.shared_graph_utils.utils import generate_log_message - - -def verifier(state: VerifierState) -> dict[str, Any]: - """ - Check whether the document is relevant for the original user question - - Args: - state (VerifierState): The current state - - Returns: - dict: ict: The updated state with the final decision - """ - - print("---SUB VERIFY QUTPUT---") - node_start_time = datetime.now() - - question = state["question"] - document_content = state["document"].combined_content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - # Grader - model = 
state["fast_llm"] - response = list( - model.stream( - prompt=msg, - structured_response_format=BinaryDecision.model_json_schema(), - ) - ) - - raw_response = json.loads(response[0].pretty_repr()) - formatted_response = BinaryDecision.model_validate(raw_response) - - return { - "deduped_retrieval_docs": [state["document"]] - if formatted_response.decision == "yes" - else [], - "log_messages": generate_log_message( - message=f"core - verifier: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/deep_qa_graph/prompts.py b/backend/danswer/agent_search/deep_qa_graph/prompts.py deleted file mode 100644 index 4e983873b70..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/prompts.py +++ /dev/null @@ -1,13 +0,0 @@ -SUB_CHECK_PROMPT = """ \n - Please check whether the suggested answer seems to address the original question. - - Please only answer with 'yes' or 'no' \n - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - Here is the proposed answer: - \n ------- \n - {base_answer} - \n ------- \n - Please answer with yes or no:""" diff --git a/backend/danswer/agent_search/deep_qa_graph/states.py b/backend/danswer/agent_search/deep_qa_graph/states.py deleted file mode 100644 index 2492f4b4ee5..00000000000 --- a/backend/danswer/agent_search/deep_qa_graph/states.py +++ /dev/null @@ -1,64 +0,0 @@ -import operator -from collections.abc import Sequence -from datetime import datetime -from typing import Annotated -from typing import TypedDict - -from langchain_core.messages import BaseMessage -from langgraph.graph.message import add_messages - -from danswer.context.search.models import InferenceSection -from danswer.llm.interfaces import LLM - - -class ResearchQAState(TypedDict): - # The 'core SubQuestion' state. - original_question: str - graph_start_time: datetime - sub_question_rewritten_queries: list[str] - sub_question: str - sub_question_nr: int - sub_question_base_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_deduped_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_verified_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_reranked_retrieval_docs: Annotated[ - Sequence[InferenceSection], operator.add - ] - sub_question_top_chunks: Annotated[Sequence[dict], operator.add] - sub_question_answer: str - sub_question_answer_check: str - log_messages: Annotated[Sequence[BaseMessage], add_messages] - sub_qas: Annotated[Sequence[dict], operator.add] - primary_llm: LLM - fast_llm: LLM - - -class ResearchQAOutputState(TypedDict): - # The 'SubQuestion' output state. 
-    # The 'SubQuestion' output state. Removes all the intermediate states
-    sub_question_rewritten_queries: list[str]
-    sub_question: str
-    sub_question_nr: int
-    # Answers sent back to core
-    sub_qas: Annotated[Sequence[dict], operator.add]
-    sub_question_base_retrieval_docs: Annotated[
-        Sequence[InferenceSection], operator.add
-    ]
-    sub_question_deduped_retrieval_docs: Annotated[
-        Sequence[InferenceSection], operator.add
-    ]
-    sub_question_verified_retrieval_docs: Annotated[
-        Sequence[InferenceSection], operator.add
-    ]
-    sub_question_reranked_retrieval_docs: Annotated[
-        Sequence[InferenceSection], operator.add
-    ]
-    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
-    sub_question_answer: str
-    sub_question_answer_check: str
-    log_messages: Annotated[Sequence[BaseMessage], add_messages]
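# --- Illustrative sketch (not part of the diff): how the Annotated reducers used in the
# states above behave. In LangGraph, Annotated[<type>, operator.add] tells the graph how
# to merge updates for a key arriving from parallel branches instead of overwriting it.
import operator
from typing import Annotated
from typing import TypedDict


class _DemoState(TypedDict):
    docs: Annotated[list[str], operator.add]


# If two parallel nodes return {"docs": ["doc_a"]} and {"docs": ["doc_b"]},
# the reducer concatenates the updates: the merged state holds ["doc_a", "doc_b"].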
diff --git a/backend/danswer/agent_search/expanded_retrieval/edges.py b/backend/danswer/agent_search/expanded_retrieval/edges.py
index 5aa8357ce5b..80c7dabb522 100644
--- a/backend/danswer/agent_search/expanded_retrieval/edges.py
+++ b/backend/danswer/agent_search/expanded_retrieval/edges.py
@@ -1,4 +1,4 @@
-from typing import Literal
+from collections.abc import Hashable
 
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_message_runs
@@ -11,20 +11,11 @@
 from danswer.llm.interfaces import LLM
 
 
-def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> Literal["doc_retrieval"]:
-    """
-    Transform the initial question into more suitable search queries.
-
-    Args:
-        state (messages): The current state
-
-    Returns:
-        dict: The updated state with re-phrased question
-    """
-
+def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
     print(f"parallel_retrieval_edge state: {state.keys()}")
-    # messages = state["base_answer_messages"]
-    question = state["query_to_answer"]
+
+    # This should be better...
+    question = state.get("query_to_answer") or state["search_request"].query
     llm: LLM = state["fast_llm"]
 
     msg = [
diff --git a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py
index ca40bbeb235..68fec83995f 100644
--- a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py
+++ b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py
@@ -4,28 +4,35 @@
 from danswer.agent_search.expanded_retrieval.edges import conditionally_rerank_edge
 from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge
-from danswer.agent_search.expanded_retrieval.nodes.collect_retrieved_docs import (
-    kick_off_verification,
-)
 from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking
 from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval
 from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
     doc_verification,
 )
+from danswer.agent_search.expanded_retrieval.nodes.verification_kickoff import (
+    verification_kickoff,
+)
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
+from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 
 
 def expanded_retrieval_graph_builder() -> StateGraph:
-    graph = StateGraph(ExpandedRetrievalState)
+    graph = StateGraph(
+        state_schema=ExpandedRetrievalState,
+        input=ExpandedRetrievalInput,
+        output=ExpandedRetrievalOutput,
+    )
+
+    ### Add nodes ###
 
     graph.add_node(
         node="doc_retrieval",
         action=doc_retrieval,
     )
     graph.add_node(
-        node="kick_off_verification",
-        action=kick_off_verification,
+        node="verification_kickoff",
+        action=verification_kickoff,
     )
     graph.add_node(
         node="doc_verification",
@@ -36,13 +43,16 @@ def expanded_retrieval_graph_builder() -> StateGraph:
         action=doc_reranking,
     )
 
+    ### Add edges ###
+
     graph.add_conditional_edges(
         source=START,
         path=parallel_retrieval_edge,
+        path_map=["doc_retrieval"],
     )
     graph.add_edge(
         start_key="doc_retrieval",
-        end_key="kick_off_verification",
+        end_key="verification_kickoff",
     )
     graph.add_conditional_edges(
         source="doc_verification",
@@ -77,7 +87,7 @@ def expanded_retrieval_graph_builder() -> StateGraph:
         primary_llm=primary_llm,
         fast_llm=fast_llm,
         db_session=db_session,
-        query_to_expand="Who made Excel?",
+        query_to_answer="Who made Excel?",
    )
    for thing in compiled_graph.stream(inputs, debug=True):
        print(thing)
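# --- Illustrative sketch (not part of the diff): the Send fan-out pattern used by
# parallel_retrieval_edge and path_map=["doc_retrieval"] above. A conditional edge that
# returns Send objects spawns one parallel run of the target node per Send; path_map
# declares the possible targets so the graph can be drawn and validated. The state keys
# below are made up for illustration.
from langgraph.types import Send


def fan_out_edge(state: dict) -> list[Send]:
    # One "doc_retrieval" branch per rewritten query; each branch gets its own sub-state.
    return [
        Send("doc_retrieval", {"query_to_retrieve": query})
        for query in state["rewritten_queries"]
    ]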
diff --git a/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py
index 05c92d0c2c8..b98de7f8d3b 100644
--- a/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py
+++ b/backend/danswer/agent_search/expanded_retrieval/nodes/doc_retrieval.py
@@ -1,50 +1,9 @@
-import random
-import time
-from datetime import datetime
-from unittest.mock import MagicMock
-
 from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
-from danswer.configs.constants import DocumentSource
-from danswer.context.search.models import InferenceChunk
 from danswer.context.search.models import InferenceSection
-
-
-def create_mock_inference_section() -> MagicMock:
-    # Create a mock InferenceChunk first
-    mock_chunk = MagicMock(spec=InferenceChunk)
-    mock_chunk.document_id = f"test_doc_id_{random.randint(1, 1000)}"
-    mock_chunk.source_type = DocumentSource.FILE
-    mock_chunk.semantic_identifier = "test_semantic_id"
-    mock_chunk.title = "Test Title"
-    mock_chunk.boost = 1
-    mock_chunk.recency_bias = 1.0
-    mock_chunk.score = 0.95
-    mock_chunk.hidden = False
-    mock_chunk.is_relevant = True
-    mock_chunk.relevance_explanation = "Test relevance"
-    mock_chunk.metadata = {"key": "value"}
-    mock_chunk.match_highlights = ["test highlight"]
-    mock_chunk.updated_at = datetime.now()
-    mock_chunk.primary_owners = ["owner1"]
-    mock_chunk.secondary_owners = ["owner2"]
-    mock_chunk.large_chunk_reference_ids = [1, 2]
-    mock_chunk.chunk_id = 1
-    mock_chunk.content = "Test content"
-    mock_chunk.blurb = "Test blurb"
-
-    # Create the InferenceSection mock
-    mock_section = MagicMock(spec=InferenceSection)
-    mock_section.center_chunk = mock_chunk
-    mock_section.chunks = [mock_chunk]
-    mock_section.combined_content = "Test combined content"
-
-    return mock_section
-
-
-def get_mock_inference_sections() -> list[InferenceSection]:
-    """Returns a list of mock InferenceSections for testing"""
-    return [create_mock_inference_section()]
+from danswer.context.search.models import SearchRequest
+from danswer.context.search.pipeline import SearchPipeline
+from danswer.db.engine import get_session_context_manager
 
 
 class RetrieveInput(ExpandedRetrievalState):
@@ -67,38 +26,22 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput:
     state["query_to_retrieve"]
 
     documents: list[InferenceSection] = []
-    state["primary_llm"]
-    state["fast_llm"]
+    llm = state["primary_llm"]
+    fast_llm = state["fast_llm"]
     # db_session = state["db_session"]
-
-    # from danswer.db.engine import get_session_context_manager
-    # with get_session_context_manager() as db_session1:
-    #     documents = SearchPipeline(
-    #         search_request=SearchRequest(
-    #             query=query_to_retrieve,
-    #         ),
-    #         user=None,
-    #         llm=llm,
-    #         fast_llm=fast_llm,
-    #         db_session=db_session1,
-    #     ).reranked_sections
-
-    time.sleep(random.random() * 10)
-
-    documents = get_mock_inference_sections()
-
-    print(f"documents: {documents}")
-
-    # return Command(
-    #     update={"retrieved_documents": documents},
-    #     goto=Send(
-    #         node="doc_verification",
-    #         arg=DocVerificationInput(
-    #             doc_to_verify=documents,
-    #             **state
-    #         ),
-    #     ),
-    # )
+    query_to_retrieve = state["search_request"].query
+    with get_session_context_manager() as db_session1:
+        documents = SearchPipeline(
+            search_request=SearchRequest(
+                query=query_to_retrieve,
+            ),
+            user=None,
+            llm=llm,
+            fast_llm=fast_llm,
+            db_session=db_session1,
+        ).reranked_sections
+
+    print(f"retrieved documents: {len(documents)}")
 
     return DocRetrievalOutput(
         retrieved_documents=documents,
     )
diff --git a/backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py b/backend/danswer/agent_search/expanded_retrieval/nodes/verification_kickoff.py
similarity index 88%
rename from backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py
rename to backend/danswer/agent_search/expanded_retrieval/nodes/verification_kickoff.py
index 4bd882a5436..56e96b8f9fa 100644
--- a/backend/danswer/agent_search/expanded_retrieval/nodes/collect_retrieved_docs.py
+++ b/backend/danswer/agent_search/expanded_retrieval/nodes/verification_kickoff.py
@@ -9,10 +9,10 @@
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 
 
-def kick_off_verification(
+def verification_kickoff(
     state: ExpandedRetrievalState,
 ) -> Command[Literal["doc_verification"]]:
-    print(f"kick_off_verification state: {state.keys()}")
+    print(f"verification_kickoff state: {state.keys()}")
     documents = state["retrieved_documents"]
 
     return Command(
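# --- Illustrative sketch (not part of the diff): the Command + Send pattern behind
# verification_kickoff above. A node can return Command(update=..., goto=[Send(...), ...])
# to update state and fan out to a target node once per item. Field names are assumptions.
from typing import Literal

from langgraph.types import Command
from langgraph.types import Send


def kickoff_node(state: dict) -> Command[Literal["doc_verification"]]:
    return Command(
        update={"retrieved_documents": state["retrieved_documents"]},
        goto=[
            Send("doc_verification", {"doc_to_verify": doc})
            for doc in state["retrieved_documents"]
        ],
    )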
print(f"verification_kickoff state: {state.keys()}") documents = state["retrieved_documents"] return Command( diff --git a/backend/danswer/agent_search/expanded_retrieval/states.py b/backend/danswer/agent_search/expanded_retrieval/states.py index 823ffaba41c..04811c32652 100644 --- a/backend/danswer/agent_search/expanded_retrieval/states.py +++ b/backend/danswer/agent_search/expanded_retrieval/states.py @@ -1,16 +1,9 @@ from typing import Annotated from typing import TypedDict -from danswer.agent_search.primary_state import PrimaryState +from danswer.agent_search.core_state import PrimaryState +from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections from danswer.context.search.models import InferenceSection -from danswer.llm.answering.prune_and_merge import _merge_sections - - -def dedup_inference_sections( - list1: list[InferenceSection], list2: list[InferenceSection] -) -> list[InferenceSection]: - deduped = _merge_sections(list1 + list2) - return deduped class DocRetrievalOutput(TypedDict, total=False): diff --git a/backend/danswer/agent_search/main/edges.py b/backend/danswer/agent_search/main/edges.py new file mode 100644 index 00000000000..0f9f6c09131 --- /dev/null +++ b/backend/danswer/agent_search/main/edges.py @@ -0,0 +1,61 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from danswer.agent_search.answer_query.states import AnswerQueryInput +from danswer.agent_search.main.states import MainState + + +def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: + return [ + Send( + "answer_query", + AnswerQueryInput( + **state, + query_to_answer=query, + ), + ) + for query in state["initial_decomp_queries"] + ] + + +# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: +# # Routes re-written queries to the (parallel) retrieval steps +# # Notice the 'Send()' API that takes care of the parallelization +# return [ +# Send( +# "sub_answers_graph", +# ResearchQAState( +# sub_question=sub_question["sub_question_str"], +# sub_question_nr=sub_question["sub_question_nr"], +# graph_start_time=state["graph_start_time"], +# primary_llm=state["primary_llm"], +# fast_llm=state["fast_llm"], +# ), +# ) +# for sub_question in state["sub_questions"] +# ] + + +# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]: +# print("---GO TO DEEP ANSWER OR END---") + +# base_answer = state["base_answer"] + +# question = state["original_question"] + +# BASE_CHECK_MESSAGE = [ +# HumanMessage( +# content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer) +# ) +# ] + +# model = state["fast_llm"] +# response = model.invoke(BASE_CHECK_MESSAGE) + +# print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? 
diff --git a/backend/danswer/agent_search/main/edges.py b/backend/danswer/agent_search/main/edges.py
new file mode 100644
index 00000000000..0f9f6c09131
--- /dev/null
+++ b/backend/danswer/agent_search/main/edges.py
@@ -0,0 +1,61 @@
+from collections.abc import Hashable
+
+from langgraph.types import Send
+
+from danswer.agent_search.answer_query.states import AnswerQueryInput
+from danswer.agent_search.main.states import MainState
+
+
+def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]:
+    return [
+        Send(
+            "answer_query",
+            AnswerQueryInput(
+                **state,
+                query_to_answer=query,
+            ),
+        )
+        for query in state["initial_decomp_queries"]
+    ]
+
+
+# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
+#     # Routes re-written queries to the (parallel) retrieval steps
+#     # Notice the 'Send()' API that takes care of the parallelization
+#     return [
+#         Send(
+#             "sub_answers_graph",
+#             ResearchQAState(
+#                 sub_question=sub_question["sub_question_str"],
+#                 sub_question_nr=sub_question["sub_question_nr"],
+#                 graph_start_time=state["graph_start_time"],
+#                 primary_llm=state["primary_llm"],
+#                 fast_llm=state["fast_llm"],
+#             ),
+#         )
+#         for sub_question in state["sub_questions"]
+#     ]
+
+
+# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
+#     print("---GO TO DEEP ANSWER OR END---")
+
+#     base_answer = state["base_answer"]
+
+#     question = state["original_question"]
+
+#     BASE_CHECK_MESSAGE = [
+#         HumanMessage(
+#             content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
+#         )
+#     ]
+
+#     model = state["fast_llm"]
+#     response = model.invoke(BASE_CHECK_MESSAGE)
+
+#     print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
+
+#     if response.pretty_repr() == "no":
+#         return "decompose"
+#     else:
+#         return "end"
diff --git a/backend/danswer/agent_search/main/graph_builder.py b/backend/danswer/agent_search/main/graph_builder.py
new file mode 100644
index 00000000000..6c783e619d2
--- /dev/null
+++ b/backend/danswer/agent_search/main/graph_builder.py
@@ -0,0 +1,99 @@
+from langgraph.graph import END
+from langgraph.graph import START
+from langgraph.graph import StateGraph
+
+from danswer.agent_search.answer_query.graph_builder import answer_query_graph_builder
+from danswer.agent_search.expanded_retrieval.graph_builder import (
+    expanded_retrieval_graph_builder,
+)
+from danswer.agent_search.main.edges import parallelize_decompozed_answer_queries
+from danswer.agent_search.main.nodes.base_decomp import main_decomp_base
+from danswer.agent_search.main.nodes.generate_initial_answer import (
+    generate_initial_answer,
+)
+from danswer.agent_search.main.states import MainInput
+from danswer.agent_search.main.states import MainState
+
+
+def main_graph_builder() -> StateGraph:
+    graph = StateGraph(
+        state_schema=MainState,
+        input=MainInput,
+    )
+
+    ### Add nodes ###
+
+    graph.add_node(
+        node="base_decomp",
+        action=main_decomp_base,
+    )
+    answer_query_subgraph = answer_query_graph_builder().compile()
+    graph.add_node(
+        node="answer_query",
+        action=answer_query_subgraph,
+    )
+    expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile()
+    graph.add_node(
+        node="expanded_retrieval",
+        action=expanded_retrieval_subgraph,
+    )
+    graph.add_node(
+        node="generate_initial_answer",
+        action=generate_initial_answer,
+    )
+
+    ### Add edges ###
+    graph.add_edge(
+        start_key=START,
+        end_key="expanded_retrieval",
+    )
+    graph.add_edge(
+        start_key="expanded_retrieval",
+        end_key="generate_initial_answer",
+    )
+
+    graph.add_edge(
+        start_key=START,
+        end_key="base_decomp",
+    )
+    graph.add_conditional_edges(
+        source="base_decomp",
+        path=parallelize_decompozed_answer_queries,
+        path_map=["answer_query"],
+    )
+    graph.add_edge(
+        start_key="answer_query",
+        end_key="generate_initial_answer",
+    )
+    graph.add_edge(
+        start_key="generate_initial_answer",
+        end_key=END,
+    )
+
+    return graph
+
+
+if __name__ == "__main__":
+    from danswer.db.engine import get_session_context_manager
+    from danswer.llm.factory import get_default_llms
+    from danswer.context.search.models import SearchRequest
+
+    graph = main_graph_builder()
+    compiled_graph = graph.compile()
+    primary_llm, fast_llm = get_default_llms()
+    search_request = SearchRequest(
+        query="Who made Excel and what other products did they make?",
+    )
+    with get_session_context_manager() as db_session:
+        inputs = MainInput(
+            search_request=search_request,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            db_session=db_session,
+        )
+        output = compiled_graph.invoke(
+            input=inputs,
+            # debug=True,
+            # subgraphs=True,
+        )
+        print(output)
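# --- Illustrative sketch (not part of the diff): the two-branch join used by
# main_graph_builder above. START fans out to both branches; a node with two incoming
# edges only runs after both have finished, and a reducer merges their updates.
import operator
from typing import Annotated
from typing import TypedDict

from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph


class _JoinState(TypedDict):
    results: Annotated[list[str], operator.add]


def _branch_a(state: _JoinState) -> dict:
    return {"results": ["retrieval branch"]}


def _branch_b(state: _JoinState) -> dict:
    return {"results": ["decomposition branch"]}


def _join(state: _JoinState) -> dict:
    # Sees the merged results from both branches.
    return {}


_g = StateGraph(_JoinState)
_g.add_node("a", _branch_a)
_g.add_node("b", _branch_b)
_g.add_node("join", _join)
_g.add_edge(START, "a")
_g.add_edge(START, "b")
_g.add_edge("a", "join")
_g.add_edge("b", "join")
_g.add_edge("join", END)
# _g.compile().invoke({"results": []}) returns both branch updates merged by the
# reducer, e.g. {"results": ["retrieval branch", "decomposition branch"]}.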
diff --git a/backend/danswer/agent_search/main/nodes/base_decomp.py b/backend/danswer/agent_search/main/nodes/base_decomp.py
new file mode 100644
index 00000000000..ce3dbd10818
--- /dev/null
+++ b/backend/danswer/agent_search/main/nodes/base_decomp.py
@@ -0,0 +1,31 @@
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.main.states import BaseDecompOutput
+from danswer.agent_search.main.states import MainState
+from danswer.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
+
+
+def main_decomp_base(state: MainState) -> BaseDecompOutput:
+    question = state["search_request"].query
+
+    msg = [
+        HumanMessage(
+            content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
+        )
+    ]
+
+    # Get the rewritten queries in a defined format
+    model = state["fast_llm"]
+    response = model.invoke(msg)
+
+    content = response.pretty_repr()
+    list_of_subquestions = clean_and_parse_list_string(content)
+
+    decomp_list: list[str] = [
+        sub_question["sub_question"].strip() for sub_question in list_of_subquestions
+    ]
+
+    return BaseDecompOutput(
+        initial_decomp_queries=decomp_list,
+    )
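# --- Illustrative sketch (not part of the diff): what clean_and_parse_list_string is
# assumed to do for main_decomp_base above. The real helper lives in
# shared_graph_utils/utils.py and may differ; this hypothetical stand-in only shows the
# expected contract: LLM text in, list of {"sub_question": ..., "search_term": ...} out.
import json
import re


def clean_and_parse_list_string_sketch(llm_text: str) -> list[dict]:
    # Strip optional markdown code fences, then parse the JSON list the
    # decomposition prompt asks the model to emit.
    cleaned = re.sub(r"```(?:json)?", "", llm_text).strip()
    return json.loads(cleaned)


# clean_and_parse_list_string_sketch(
#     '[{"sub_question": "What are sales for company A?", "search_term": "company A sales"}]'
# )
# -> [{"sub_question": "What are sales for company A?", "search_term": "company A sales"}]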
diff --git a/backend/danswer/agent_search/main/nodes/generate_initial_answer.py b/backend/danswer/agent_search/main/nodes/generate_initial_answer.py
new file mode 100644
index 00000000000..96e48bc1bda
--- /dev/null
+++ b/backend/danswer/agent_search/main/nodes/generate_initial_answer.py
@@ -0,0 +1,51 @@
+from langchain_core.messages import HumanMessage
+
+from danswer.agent_search.main.states import InitialAnswerOutput
+from danswer.agent_search.main.states import MainState
+from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
+from danswer.agent_search.shared_graph_utils.utils import format_docs
+
+
+def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
+    print("---GENERATE INITIAL---")
+
+    question = state["search_request"].query
+    docs = state["documents"]
+
+    decomp_answer_results = state["decomp_answer_results"]
+
+    good_qa_list: list[str] = []
+
+    _SUB_QUESTION_ANSWER_TEMPLATE = """
+    Sub-Question:\n  - {sub_question}\n  --\nAnswer:\n  - {sub_answer}\n\n
+    """
+    for decomp_answer_result in decomp_answer_results:
+        if (
+            decomp_answer_result.quality == "yes"
+            and len(decomp_answer_result.answer) > 0
+            and decomp_answer_result.answer != "I don't know"
+        ):
+            good_qa_list.append(
+                _SUB_QUESTION_ANSWER_TEMPLATE.format(
+                    sub_question=decomp_answer_result.query,
+                    sub_answer=decomp_answer_result.answer,
+                )
+            )
+
+    sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
+
+    msg = [
+        HumanMessage(
+            content=INITIAL_RAG_PROMPT.format(
+                question=question,
+                context=format_docs(docs),
+                answered_sub_questions=sub_question_answer_str,
+            )
+        )
+    ]
+
+    # Grader
+    model = state["fast_llm"]
+    response = model.invoke(msg)
+
+    return InitialAnswerOutput(initial_answer=response.pretty_repr())
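# --- Illustrative sketch (not part of the diff): the filtering rule applied by
# generate_initial_answer above before sub-answers reach the prompt. Values are invented.
_template = "Sub-Question:\n  - {sub_question}\n  --\nAnswer:\n  - {sub_answer}\n\n"


def _keep(answer: str, quality: str) -> bool:
    # Mirrors the node's three conditions: graded "yes", non-empty, not a refusal.
    return quality == "yes" and len(answer) > 0 and answer != "I don't know"


if _keep("Microsoft developed Excel.", "yes"):
    print(
        _template.format(
            sub_question="Who made Excel?", sub_answer="Microsoft developed Excel."
        )
    )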
diff --git a/backend/danswer/agent_search/main/states.py b/backend/danswer/agent_search/main/states.py
new file mode 100644
index 00000000000..e6124f5a5eb
--- /dev/null
+++ b/backend/danswer/agent_search/main/states.py
@@ -0,0 +1,37 @@
+from operator import add
+from typing import Annotated
+from typing import TypedDict
+
+from danswer.agent_search.answer_query.states import SearchAnswerResults
+from danswer.agent_search.core_state import PrimaryState
+from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
+from danswer.context.search.models import InferenceSection
+
+
+class BaseDecompOutput(TypedDict, total=False):
+    initial_decomp_queries: list[str]
+
+
+class InitialAnswerOutput(TypedDict, total=False):
+    initial_answer: str
+
+
+class MainState(
+    PrimaryState,
+    BaseDecompOutput,
+    InitialAnswerOutput,
+    total=True,
+):
+    documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    decomp_answer_results: Annotated[list[SearchAnswerResults], add]
+
+
+class MainInput(PrimaryState, total=True):
+    pass
+
+
+class MainOutput(TypedDict):
+    """
+    This is not used because defining the output only matters for filtering the output
+    of a .invoke() call, but we are streaming, so we just yield the entire state.
+    """
diff --git a/backend/danswer/agent_search/primary_graph/edges.py b/backend/danswer/agent_search/primary_graph/edges.py
deleted file mode 100644
index 495ae807f79..00000000000
--- a/backend/danswer/agent_search/primary_graph/edges.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from collections.abc import Hashable
-from typing import Union
-
-from langchain_core.messages import HumanMessage
-from langgraph.types import Send
-
-from danswer.agent_search.core_qa_graph.states import BaseQAState
-from danswer.agent_search.deep_qa_graph.states import ResearchQAState
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
-
-
-def continue_to_initial_sub_questions(
-    state: QAState,
-) -> Union[Hashable, list[Hashable]]:
-    # Routes re-written queries to the (parallel) retrieval steps
-    # Notice the 'Send()' API that takes care of the parallelization
-    return [
-        Send(
-            "sub_answers_graph_initial",
-            BaseQAState(
-                sub_question_str=initial_sub_question["sub_question_str"],
-                sub_question_search_queries=initial_sub_question[
-                    "sub_question_search_queries"
-                ],
-                sub_question_nr=initial_sub_question["sub_question_nr"],
-                primary_llm=state["primary_llm"],
-                fast_llm=state["fast_llm"],
-                graph_start_time=state["graph_start_time"],
-            ),
-        )
-        for initial_sub_question in state["initial_sub_questions"]
-    ]
-
-
-def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
-    # Routes re-written queries to the (parallel) retrieval steps
-    # Notice the 'Send()' API that takes care of the parallelization
-    return [
-        Send(
-            "sub_answers_graph",
-            ResearchQAState(
-                sub_question=sub_question["sub_question_str"],
-                sub_question_nr=sub_question["sub_question_nr"],
-                graph_start_time=state["graph_start_time"],
-                primary_llm=state["primary_llm"],
-                fast_llm=state["fast_llm"],
-            ),
-        )
-        for sub_question in state["sub_questions"]
-    ]
-
-
-def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
-    print("---GO TO DEEP ANSWER OR END---")
-
-    base_answer = state["base_answer"]
-
-    question = state["original_question"]
-
-    BASE_CHECK_MESSAGE = [
-        HumanMessage(
-            content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
-        )
-    ]
-
-    model = state["fast_llm"]
-    response = model.invoke(BASE_CHECK_MESSAGE)
-
-    print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
-
-    if response.pretty_repr() == "no":
-        return "decompose"
-    else:
-        return "end"
diff --git a/backend/danswer/agent_search/primary_graph/graph_builder.py b/backend/danswer/agent_search/primary_graph/graph_builder.py
deleted file mode 100644
index 6bce818514c..00000000000
--- a/backend/danswer/agent_search/primary_graph/graph_builder.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from langgraph.graph import END
-from langgraph.graph import START
-from langgraph.graph import StateGraph
-
-from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
-from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
-from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
-from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
-from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
-from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
-from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
-    combine_retrieved_docs,
-)
-from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
-from danswer.agent_search.primary_graph.nodes.decompose import decompose
-from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
-    deep_answer_generation,
-)
-from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
-from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
-    entity_term_extraction,
-)
-from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
-from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
-from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
-from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
-from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
-    sub_qa_level_aggregator,
-)
-from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
-from danswer.agent_search.primary_graph.nodes.verifier import verifier
-from danswer.agent_search.primary_graph.states import QAState
-
-
-def build_core_graph() -> StateGraph:
-    # Define the nodes we will cycle between
-    core_answer_graph = StateGraph(state_schema=QAState)
-
-    ### Add Nodes ###
-    core_answer_graph.add_node(node="dummy_start", action=dummy_start)
-
-    # Re-writing the question
-    core_answer_graph.add_node(node="rewrite", action=rewrite)
-
-    # The retrieval step
-    core_answer_graph.add_node(node="custom_retrieve", action=custom_retrieve)
-
-    # Combine and dedupe retrieved docs.
-    core_answer_graph.add_node(
-        node="combine_retrieved_docs", action=combine_retrieved_docs
-    )
-
-    # Extract entities, terms and relationships
-    core_answer_graph.add_node(
-        node="entity_term_extraction", action=entity_term_extraction
-    )
-
-    # Verifying that a retrieved doc is relevant
-    core_answer_graph.add_node(node="verifier", action=verifier)
-
-    # Initial question decomposition
-    core_answer_graph.add_node(node="main_decomp_base", action=main_decomp_base)
-
-    # Build the base QA sub-graph and compile it
-    compiled_core_qa_graph = build_core_qa_graph().compile()
-    # Add the compiled base QA sub-graph as a node to the core graph
-    core_answer_graph.add_node(
-        node="sub_answers_graph_initial", action=compiled_core_qa_graph
-    )
-
-    # Checking whether the initial answer is in the ballpark
-    core_answer_graph.add_node(node="base_wait", action=base_wait)
-
-    # Decompose the question into sub-questions
-    core_answer_graph.add_node(node="decompose", action=decompose)
-
-    # Manage the sub-questions
-    core_answer_graph.add_node(node="sub_qa_manager", action=sub_qa_manager)
-
-    # Build the research QA sub-graph and compile it
-    compiled_deep_qa_graph = build_deep_qa_graph().compile()
-    # Add the compiled research QA sub-graph as a node to the core graph
-    core_answer_graph.add_node(node="sub_answers_graph", action=compiled_deep_qa_graph)
-
-    # Aggregate the sub-questions
-    core_answer_graph.add_node(
-        node="sub_qa_level_aggregator", action=sub_qa_level_aggregator
-    )
-
-    # aggregate sub questions and answers
-    core_answer_graph.add_node(
-        node="deep_answer_generation", action=deep_answer_generation
-    )
-
-    # A final clean-up step
-    core_answer_graph.add_node(node="final_stuff", action=final_stuff)
-
-    # Generating a response after we know the documents are relevant
-    core_answer_graph.add_node(node="generate_initial", action=generate_initial)
-
-    ### Add Edges ###
-
-    # start the initial sub-question decomposition
-    core_answer_graph.add_edge(start_key=START, end_key="main_decomp_base")
-    core_answer_graph.add_conditional_edges(
-        source="main_decomp_base",
-        path=continue_to_initial_sub_questions,
-    )
-
-    # use the retrieved information to generate the answer
-    core_answer_graph.add_edge(
-        start_key=["verifier", "sub_answers_graph_initial"], end_key="generate_initial"
-    )
-    core_answer_graph.add_edge(start_key="generate_initial", end_key="base_wait")
-
-    core_answer_graph.add_conditional_edges(
-        source="base_wait",
-        path=continue_to_deep_answer,
-        path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
-    )
-
-    core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")
-
-    core_answer_graph.add_edge(start_key="decompose", end_key="sub_qa_manager")
-    core_answer_graph.add_conditional_edges(
-        source="sub_qa_manager",
-        path=continue_to_answer_sub_questions,
-    )
-
-    core_answer_graph.add_edge(
-        start_key="sub_answers_graph", end_key="sub_qa_level_aggregator"
-    )
-
-    core_answer_graph.add_edge(
-        start_key="sub_qa_level_aggregator", end_key="deep_answer_generation"
-    )
-
-    core_answer_graph.add_edge(
-        start_key="deep_answer_generation", end_key="final_stuff"
-    )
-
-    core_answer_graph.add_edge(start_key="final_stuff", end_key=END)
-    core_answer_graph.compile()
-
-    return core_answer_graph
diff --git a/backend/danswer/agent_search/primary_graph/nodes/base_wait.py b/backend/danswer/agent_search/primary_graph/nodes/base_wait.py
deleted file mode 100644
index e46aa530f2c..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/base_wait.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def base_wait(state: QAState) -> dict[str, Any]:
-    """
-    Ensures that all required steps are completed before proceeding to the next step
-
-    Args:
-        state (messages): The current state
-
-    Returns:
-        dict: {} (no operation, just logging)
-    """
-
-    print("---Base Wait ---")
-    node_start_time = datetime.now()
-    return {
-        "log_messages": generate_log_message(
-            message="core - base_wait",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py b/backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py
deleted file mode 100644
index a175b74e5a6..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/combine_retrieved_docs.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from collections.abc import Sequence
-from datetime import datetime
-from typing import Any
-
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-from danswer.context.search.models import InferenceSection
-
-
-def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
-    """
-    Dedupe the retrieved docs.
-    """
-    node_start_time = datetime.now()
-
-    base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]
-
-    print(f"Number of docs from steps: {len(base_retrieval_docs)}")
-    dedupe_docs: list[InferenceSection] = []
-    for base_retrieval_doc in base_retrieval_docs:
-        if not any(
-            base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
-            for doc in dedupe_docs
-        ):
-            dedupe_docs.append(base_retrieval_doc)
-
-    print(f"Number of deduped docs: {len(dedupe_docs)}")
-
-    return {
-        "deduped_retrieval_docs": dedupe_docs,
-        "log_messages": generate_log_message(
-            message="core - combine_retrieved_docs (dedupe)",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py b/backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py
deleted file mode 100644
index deaafbdf411..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/custom_retrieve.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from danswer.agent_search.primary_graph.states import RetrieverState
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-from danswer.context.search.models import InferenceSection
-from danswer.context.search.models import SearchRequest
-from danswer.context.search.pipeline import SearchPipeline
-from danswer.db.engine import get_session_context_manager
-from danswer.llm.factory import get_default_llms
-
-
-def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
-    """
-    Retrieve documents
-
-    Args:
-        retriever_state (dict): The current graph state
-
-    Returns:
-        state (dict): New key added to state, documents, that contains retrieved documents
-    """
-    print("---RETRIEVE---")
-
-    node_start_time = datetime.now()
-
-    query = state["rewritten_query"]
-
-    # Retrieval
-    # TODO: add the actual retrieval, probably from search_tool.run()
-    llm, fast_llm = get_default_llms()
-    with get_session_context_manager() as db_session:
-        top_sections = SearchPipeline(
-            search_request=SearchRequest(
-                query=query,
-            ),
-            user=None,
-            llm=llm,
-            fast_llm=fast_llm,
-            db_session=db_session,
-        ).reranked_sections
-    print(len(top_sections))
-    documents: list[InferenceSection] = []
-
-    return {
-        "base_retrieval_docs": documents,
-        "log_messages": generate_log_message(
-            message="core - custom_retrieve",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py b/backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py
deleted file mode 100644
index 55d41162e5b..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/deep_answer_generation.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from langchain_core.messages import HumanMessage
-
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
-from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
-from danswer.agent_search.shared_graph_utils.utils import format_docs
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
-
-
-# aggregate sub questions and answers
-def deep_answer_generation(state: QAState) -> dict[str, Any]:
-    """
-    Generate answer
-
-    Args:
-        state (messages): The current state
-
-    Returns:
-        dict: The updated state with re-phrased question
-    """
-    print("---DEEP GENERATE---")
-
-    node_start_time = datetime.now()
-
-    question = state["original_question"]
-    docs = state["deduped_retrieval_docs"]
-
-    deep_answer_context = state["core_answer_dynamic_context"]
-
-    print(f"Number of verified retrieval docs - deep: {len(docs)}")
-
-    combined_context = normalize_whitespace(
-        COMBINED_CONTEXT.format(
-            deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
-        )
-    )
-
-    msg = [
-        HumanMessage(
-            content=MODIFIED_RAG_PROMPT.format(
-                question=question, combined_context=combined_context
-            )
-        )
-    ]
-
-    # Grader
-    model = state["fast_llm"]
-    response = model.invoke(msg)
-
-    return {
-        "deep_answer": response.content,
-        "log_messages": generate_log_message(
-            message="deep - deep answer generation",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/dummy_start.py b/backend/danswer/agent_search/primary_graph/nodes/dummy_start.py
deleted file mode 100644
index 62e3dd92a7d..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/dummy_start.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from danswer.agent_search.primary_graph.states import QAState
-
-
-def dummy_start(state: QAState) -> dict[str, Any]:
-    """
-    Dummy node to set the start time
-    """
-    return {"start_time": datetime.now()}
diff --git a/backend/danswer/agent_search/primary_graph/nodes/generate.py b/backend/danswer/agent_search/primary_graph/nodes/generate.py
deleted file mode 100644
index 9ff707177d8..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/generate.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from langchain_core.messages import HumanMessage
-
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
-from danswer.agent_search.shared_graph_utils.utils import format_docs
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def generate(state: QAState) -> dict[str, Any]:
-    """
-    Generate answer
-
-    Args:
-        state (messages): The current state
-
-    Returns:
-        dict: The updated state with re-phrased question
-    """
-    print("---GENERATE---")
-    node_start_time = datetime.now()
-
-    question = state["original_question"]
-    docs = state["deduped_retrieval_docs"]
-
-    print(f"Number of verified retrieval docs: {len(docs)}")
-
-    msg = [
-        HumanMessage(
-            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
-        )
-    ]
-
-    # Grader
-    llm = state["fast_llm"]
-    response = list(
-        llm.stream(
-            prompt=msg,
-            structured_response_format=None,
-        )
-    )
-
-    return {
-        "base_answer": response[0].pretty_repr(),
-        "log_messages": generate_log_message(
-            message="core - generate",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/generate_initial.py b/backend/danswer/agent_search/primary_graph/nodes/generate_initial.py
deleted file mode 100644
index 56ad83de96e..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/generate_initial.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from langchain_core.messages import HumanMessage
-
-from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.utils import format_docs
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def generate_initial(state: QAState) -> dict[str, Any]:
-    """
-    Generate answer
-
-    Args:
-        state (messages): The current state
-
-    Returns:
-        dict: The updated state with re-phrased question
-    """
-    print("---GENERATE INITIAL---")
-    node_start_time = datetime.now()
-
-    question = state["original_question"]
-    docs = state["deduped_retrieval_docs"]
-    print(f"Number of verified retrieval docs - base: {len(docs)}")
-
-    sub_question_answers = state["initial_sub_qas"]
-
-    sub_question_answers_list = []
-
-    _SUB_QUESTION_ANSWER_TEMPLATE = """
-    Sub-Question:\n  - {sub_question}\n  --\nAnswer:\n  - {sub_answer}\n\n
-    """
-    for sub_question_answer_dict in sub_question_answers:
-        if (
-            sub_question_answer_dict["sub_answer_check"] == "yes"
-            and len(sub_question_answer_dict["sub_answer"]) > 0
-            and sub_question_answer_dict["sub_answer"] != "I don't know"
-        ):
-            sub_question_answers_list.append(
-                _SUB_QUESTION_ANSWER_TEMPLATE.format(
-                    sub_question=sub_question_answer_dict["sub_question"],
-                    sub_answer=sub_question_answer_dict["sub_answer"],
-                )
-            )
-
-    sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)
-
-    msg = [
-        HumanMessage(
-            content=INITIAL_RAG_PROMPT.format(
-                question=question,
-                context=format_docs(docs),
-                answered_sub_questions=sub_question_answer_str,
-            )
-        )
-    ]
-
-    # Grader
-    model = state["fast_llm"]
-    response = model.invoke(msg)
-
-    return {
-        "base_answer": response.pretty_repr(),
-        "log_messages": generate_log_message(
-            message="core - generate initial",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py b/backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py
deleted file mode 100644
index f07100d61af..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/main_decomp_base.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from langchain_core.messages import HumanMessage
-
-from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def main_decomp_base(state: QAState) -> dict[str, Any]:
-    """
-    Perform an initial question decomposition, incl. one search term
-
-    Args:
-        state (messages): The current state
-
-    Returns:
-        dict: The updated state with initial decomposition
-    """
-
-    print("---INITIAL DECOMP---")
-    node_start_time = datetime.now()
-
-    question = state["original_question"]
-
-    msg = [
-        HumanMessage(
-            content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
-        )
-    ]
-
-    # Get the rewritten queries in a defined format
-    model = state["fast_llm"]
-    response = model.invoke(msg)
-
-    content = response.pretty_repr()
-    list_of_subquestions = clean_and_parse_list_string(content)
-
-    decomp_list = []
-
-    for sub_question_nr, sub_question in enumerate(list_of_subquestions):
-        sub_question_str = sub_question["sub_question"].strip()
-        # temporarily
-        sub_question_search_queries = [sub_question["search_term"]]
-
-        decomp_list.append(
-            {
-                "sub_question_str": sub_question_str,
-                "sub_question_search_queries": sub_question_search_queries,
-                "sub_question_nr": sub_question_nr,
-            }
-        )
-
-    return {
-        "initial_sub_questions": decomp_list,
-        "sub_query_start_time": node_start_time,
-        "log_messages": generate_log_message(
-            message="core - initial decomp",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/rewrite.py b/backend/danswer/agent_search/primary_graph/nodes/rewrite.py
deleted file mode 100644
index 07cbba5432c..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/rewrite.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import json
-from datetime import datetime
-from typing import Any
-
-from langchain_core.messages import HumanMessage
-
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
-from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def rewrite(state: QAState) -> dict[str, Any]:
-    """
-    Transform the initial question into more suitable search queries.
-
-    Args:
-        qa_state (messages): The current state
-
-    Returns:
-        dict: The updated state with re-phrased question
-    """
-    print("---STARTING GRAPH---")
-    graph_start_time = datetime.now()
-
-    print("---TRANSFORM QUERY---")
-    node_start_time = datetime.now()
-
-    question = state["original_question"]
-
-    msg = [
-        HumanMessage(
-            content=REWRITE_PROMPT_MULTI.format(question=question),
-        )
-    ]
-
-    # Get the rewritten queries in a defined format
-    fast_llm = state["fast_llm"]
-    llm_response = list(
-        fast_llm.stream(
-            prompt=msg,
-            structured_response_format=RewrittenQueries.model_json_schema(),
-        )
-    )
-
-    formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
-
-    return {
-        "rewritten_queries": formatted_response.rewritten_queries,
-        "log_messages": generate_log_message(
-            message="core - rewrite",
-            node_start_time=node_start_time,
-            graph_start_time=graph_start_time,
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py b/backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py
deleted file mode 100644
index 6e81dfb5dea..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/sub_qa_manager.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from danswer.agent_search.primary_graph.states import QAState
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def sub_qa_manager(state: QAState) -> dict[str, Any]:
-    """ """
-
-    node_start_time = datetime.now()
-
-    sub_questions_dict = state["decomposed_sub_questions_dict"]
-
-    sub_questions = {}
-
-    for sub_question_nr, sub_question_dict in sub_questions_dict.items():
-        sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
-
-    return {
-        "sub_questions": sub_questions,
-        "num_new_question_iterations": 0,
-        "log_messages": generate_log_message(
-            message="deep - sub qa manager",
-            node_start_time=node_start_time,
-            graph_start_time=state["graph_start_time"],
-        ),
-    }
diff --git a/backend/danswer/agent_search/primary_graph/nodes/verifier.py b/backend/danswer/agent_search/primary_graph/nodes/verifier.py
deleted file mode 100644
index 1efdba03109..00000000000
--- a/backend/danswer/agent_search/primary_graph/nodes/verifier.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import json
-from datetime import datetime
-from typing import Any
-
-from langchain_core.messages import HumanMessage
-
-from danswer.agent_search.primary_graph.states import VerifierState
-from danswer.agent_search.shared_graph_utils.models import BinaryDecision
-from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
-from danswer.agent_search.shared_graph_utils.utils import generate_log_message
-
-
-def verifier(state: VerifierState) -> dict[str, Any]:
-    """
-    Check whether the document is relevant for the original user question
-
-    Args:
-        state (VerifierState): The current state
-
-    Returns:
-        dict: ict: The updated state with the final decision
-    """
-
-    print("---VERIFY QUTPUT---")
-    node_start_time = datetime.now()
-
-    question = state["question"]
-    document_content = state["document"].combined_content
-
-    msg = [
-        HumanMessage(
-            content=VERIFIER_PROMPT.format(
-                question=question, document_content=document_content
-            )
-        )
-    ]
-
-    # Grader
-    llm = state["fast_llm"]
-    response = list(
-        llm.stream(
-            prompt=msg,
-            structured_response_format=BinaryDecision.model_json_schema(),
-        )
-    )
-
-    raw_response = json.loads(response[0].pretty_repr())
BinaryDecision.model_validate(raw_response) - - return { - "deduped_retrieval_docs": [state["document"]] - if formatted_response.decision == "yes" - else [], - "log_messages": generate_log_message( - message=f"core - verifier: {formatted_response.decision}", - node_start_time=node_start_time, - graph_start_time=state["graph_start_time"], - ), - } diff --git a/backend/danswer/agent_search/primary_graph/prompts.py b/backend/danswer/agent_search/primary_graph/prompts.py deleted file mode 100644 index 0eafd70d275..00000000000 --- a/backend/danswer/agent_search/primary_graph/prompts.py +++ /dev/null @@ -1,86 +0,0 @@ -INITIAL_DECOMPOSITION_PROMPT = """ \n - Please decompose an initial user question into not more than 4 appropriate sub-questions that help to - answer the original question. The purpose for this decomposition is to isolate individulal entities - (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales - for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our - sales with company A' + 'what is our market share with company A' + 'is company A a reference customer - for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n - - For each sub-question, please also create one search term that can be used to retrieve relevant - documents from a document store. - - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - - Please formulate your answer as a list of json objects with the following format: - - [{{"sub_question": , "search_term": }}, ...] - - Answer: - """ - -INITIAL_RAG_PROMPT = """ \n - You are an assistant for question-answering tasks. Use the information provided below - and only the - provided information - to answer the provided question. - - The information provided below consists of: - 1) a number of answered sub-questions - these are very important(!) and definitely should be - considered to answer the question. - 2) a number of documents that were also deemed relevant for the question. - - If you don't know the answer or if the provided information is empty or insufficient, just say - "I don't know". Do not use your internal knowledge! - - Again, only use the provided informationand do not use your internal knowledge! It is a matter of life - and death that you do NOT use your internal knowledge, just the provided information! - - Try to keep your answer concise. - - And here is the question and the provided information: - \n - \nQuestion:\n {question} - - \nAnswered Sub-questions:\n {answered_sub_questions} - - \nContext:\n {context} \n\n - \n\n - - Answer:""" - -ENTITY_TERM_PROMPT = """ \n - Based on the original question and the context retieved from a dataset, please generate a list of - entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts - (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other. 
- - \n\n - Here is the original question: - \n ------- \n - {question} - \n ------- \n - And here is the context retrieved: - \n ------- \n - {context} - \n ------- \n - - Please format your answer as a json object in the following format: - - {{"retrieved_entities_relationships": {{ - "entities": [{{ - "entity_name": , - "entity_type": - }}], - "relationships": [{{ - "name": , - "type": , - "entities": [, ] - }}], - "terms": [{{ - "term_name": , - "term_type": , - "similar_to": - }}] - }} - }} - """ diff --git a/backend/danswer/agent_search/primary_graph/states.py b/backend/danswer/agent_search/primary_graph/states.py deleted file mode 100644 index 2e59fedfcd7..00000000000 --- a/backend/danswer/agent_search/primary_graph/states.py +++ /dev/null @@ -1,73 +0,0 @@ -import operator -from collections.abc import Sequence -from datetime import datetime -from typing import Annotated -from typing import TypedDict - -from langchain_core.messages import BaseMessage -from langgraph.graph.message import add_messages - -from danswer.agent_search.shared_graph_utils.models import RewrittenQueries -from danswer.context.search.models import InferenceSection - - -class QAState(TypedDict): - # The 'main' state of the answer graph - original_question: str - graph_start_time: datetime - # start time for parallel initial sub-questionn thread - sub_query_start_time: datetime - log_messages: Annotated[Sequence[BaseMessage], add_messages] - rewritten_queries: RewrittenQueries - sub_questions: list[dict] - initial_sub_questions: list[dict] - ranked_subquestion_ids: list[int] - decomposed_sub_questions_dict: dict - rejected_sub_questions: Annotated[list[str], operator.add] - rejected_sub_questions_handled: bool - sub_qas: Annotated[Sequence[dict], operator.add] - initial_sub_qas: Annotated[Sequence[dict], operator.add] - checked_sub_qas: Annotated[Sequence[dict], operator.add] - base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add] - deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add] - reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add] - retrieved_entities_relationships: dict - questions_context: list[dict] - qa_level: int - top_chunks: list[InferenceSection] - sub_question_top_chunks: Annotated[Sequence[dict], operator.add] - num_new_question_iterations: int - core_answer_dynamic_context: str - dynamic_context: str - initial_base_answer: str - base_answer: str - deep_answer: str - - -class QAOuputState(TypedDict): - # The 'main' output state of the answer graph. Removes all the intermediate states - original_question: str - log_messages: Annotated[Sequence[BaseMessage], add_messages] - sub_questions: list[dict] - sub_qas: Annotated[Sequence[dict], operator.add] - initial_sub_qas: Annotated[Sequence[dict], operator.add] - checked_sub_qas: Annotated[Sequence[dict], operator.add] - reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add] - retrieved_entities_relationships: dict - top_chunks: list[InferenceSection] - sub_question_top_chunks: Annotated[Sequence[dict], operator.add] - base_answer: str - deep_answer: str - - -class RetrieverState(TypedDict): - # The state for the parallel Retrievers. They each need to see only one query - rewritten_query: str - graph_start_time: datetime - - -class VerifierState(TypedDict): - # The state for the parallel verification step. 
Each node execution need to see only one question/doc pair
-    document: InferenceSection
-    question: str
-    graph_start_time: datetime
diff --git a/backend/danswer/agent_search/shared_graph_utils/models.py b/backend/danswer/agent_search/shared_graph_utils/models.py
index ed731fbc566..162d651fe51 100644
--- a/backend/danswer/agent_search/shared_graph_utils/models.py
+++ b/backend/danswer/agent_search/shared_graph_utils/models.py
@@ -10,7 +10,3 @@ class RewrittenQueries(BaseModel):
 
 class BinaryDecision(BaseModel):
     decision: Literal["yes", "no"]
-
-
-class SubQuestions(BaseModel):
-    sub_questions: list[str]
diff --git a/backend/danswer/agent_search/shared_graph_utils/operators.py b/backend/danswer/agent_search/shared_graph_utils/operators.py
new file mode 100644
index 00000000000..f6e5c91ebda
--- /dev/null
+++ b/backend/danswer/agent_search/shared_graph_utils/operators.py
@@ -0,0 +1,9 @@
+from danswer.context.search.models import InferenceSection
+from danswer.llm.answering.prune_and_merge import _merge_sections
+
+
+def dedup_inference_sections(
+    list1: list[InferenceSection], list2: list[InferenceSection]
+) -> list[InferenceSection]:
+    deduped = _merge_sections(list1 + list2)
+    return deduped
diff --git a/backend/danswer/agent_search/shared_graph_utils/prompts.py b/backend/danswer/agent_search/shared_graph_utils/prompts.py
index 3fb690c3cb4..f5b40a0ecb0 100644
--- a/backend/danswer/agent_search/shared_graph_utils/prompts.py
+++ b/backend/danswer/agent_search/shared_graph_utils/prompts.py
@@ -327,3 +327,91 @@
 
     {{"reasonning": , "ranked_motivations": }}
     """
+
+
+INITIAL_DECOMPOSITION_PROMPT = """ \n
+    Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
+    answer the original question. The purpose for this decomposition is to isolate individual entities
+    (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+    for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+    sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+    for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
+
+    For each sub-question, please also create one search term that can be used to retrieve relevant
+    documents from a document store.
+
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Please formulate your answer as a list of json objects with the following format:
+
+    [{{"sub_question": , "search_term": }}, ...]
+
+    Answer:
+    """
+
+INITIAL_RAG_PROMPT = """ \n
+    You are an assistant for question-answering tasks. Use the information provided below - and only the
+    provided information - to answer the provided question.
+
+    The information provided below consists of:
+    1) a number of answered sub-questions - these are very important(!) and definitely should be
+    considered to answer the question.
+    2) a number of documents that were also deemed relevant for the question.
+
+    If you don't know the answer or if the provided information is empty or insufficient, just say
+    "I don't know". Do not use your internal knowledge!
+
+    Again, only use the provided information and do not use your internal knowledge! It is a matter of life
+    and death that you do NOT use your internal knowledge, just the provided information!
+
+    Try to keep your answer concise.
+
+    And here is the question and the provided information:
+    \n
+    \nQuestion:\n {question}
+
+    \nAnswered Sub-questions:\n {answered_sub_questions}
+
+    \nContext:\n {context} \n\n
+    \n\n
+
+    Answer:"""
+
+ENTITY_TERM_PROMPT = """ \n
+    Based on the original question and the context retrieved from a dataset, please generate a list of
+    entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
+    (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
+
+    \n\n
+    Here is the original question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    And here is the context retrieved:
+    \n ------- \n
+    {context}
+    \n ------- \n
+
+    Please format your answer as a json object in the following format:
+
+    {{"retrieved_entities_relationships": {{
+        "entities": [{{
+            "entity_name": ,
+            "entity_type": 
+        }}],
+        "relationships": [{{
+            "name": ,
+            "type": ,
+            "entities": [, ]
+        }}],
+        "terms": [{{
+            "term_name": ,
+            "term_type": ,
+            "similar_to": 
+        }}]
+    }}
+    }}
+    """
diff --git a/backend/danswer/agent_search/shared_graph_utils/utils.py b/backend/danswer/agent_search/shared_graph_utils/utils.py
index 95faf287ae0..24c505ac585 100644
--- a/backend/danswer/agent_search/shared_graph_utils/utils.py
+++ b/backend/danswer/agent_search/shared_graph_utils/utils.py
@@ -22,12 +22,22 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
 
 
 def clean_and_parse_list_string(json_string: str) -> list[dict]:
+    # Remove any prefixes/labels before the actual JSON content
+    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
+    # Remove markdown code block markers and any newline prefixes
     cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
     cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
     cleaned_string = " ".join(cleaned_string.split())
-    # Parse the cleaned string into a Python dictionary
-    return ast.literal_eval(cleaned_string)
+
+    # Try parsing with json.loads first, fall back to ast.literal_eval
+    try:
+        return json.loads(cleaned_string)
+    except json.JSONDecodeError:
+        try:
+            return ast.literal_eval(cleaned_string)
+        except (ValueError, SyntaxError) as e:
+            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
 
 
 def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
diff --git a/backend/danswer/agent_search/test.py b/backend/danswer/agent_search/test.py
deleted file mode 100644
index 1f760e4dc90..00000000000
--- a/backend/danswer/agent_search/test.py
+++ /dev/null
@@ -1,120 +0,0 @@
-from typing import Annotated
-from typing import Literal
-from typing import TypedDict
-
-from dotenv import load_dotenv
-from langgraph.graph import END
-from langgraph.graph import START
-from langgraph.graph import StateGraph
-from langgraph.types import Send
-
-
-def unique_concat(a: list[str], b: list[str]) -> list[str]:
-    combined = a + b
-    return list(set(combined))
-
-
-load_dotenv(".vscode/.env")
-
-
-class InputState(TypedDict):
-    user_input: str
-    # str_arr: list[str]
-
-
-class OutputState(TypedDict):
-    graph_output: str
-
-
-class SharedState(TypedDict):
-    llm: int
-
-
-class OverallState(TypedDict):
-    foo: str
-    user_input: str
-    str_arr: Annotated[list[str], unique_concat]
-    # str_arr: list[str]
-
-
-class PrivateState(TypedDict):
-    foo: str
-    bar: str
-
-
-def conditional_edge_3(state: PrivateState) -> Literal["node_4"]:
-    print(f"conditional_edge_3: {state}")
-    return Send(
-        "node_4",
-        state,
-    )
-
-
-def node_1(state: OverallState):
-    print(f"node_1: {state}")
-    
return {
-        "foo": state["user_input"] + " name",
-        "user_input": state["user_input"],
-        "str_arr": ["a", "b", "c"],
-    }
-
-
-def node_2(state: OverallState):
-    print(f"node_2: {state}")
-    return {
-        "foo": "foo",
-        "bar": "bar",
-        "test1": "test1",
-        "str_arr": ["a", "d", "e", "f"],
-    }
-
-
-def node_3(state: PrivateState):
-    print(f"node_3: {state}")
-    return {"bar": state["bar"] + " Lance"}
-
-
-def node_4(state: PrivateState):
-    print(f"node_4: {state}")
-    return {
-        "foo": state["bar"] + " more bar",
-    }
-
-
-def node_5(state: OverallState):
-    print(f"node_5: {state}")
-    updated_aggregate = [item for item in state["str_arr"] if "b" not in item]
-    print(f"updated_aggregate: {updated_aggregate}")
-    return {"str_arr": updated_aggregate}
-
-
-builder = StateGraph(
-    state_schema=OverallState,
-    # input=InputState,
-    # output=OutputState
-)
-builder.add_node("node_1", node_1)
-builder.add_node("node_2", node_2)
-builder.add_node("node_3", node_3)
-builder.add_node("node_4", node_4)
-builder.add_node("node_5", node_5)
-builder.add_edge(START, "node_1")
-builder.add_edge("node_1", "node_2")
-builder.add_edge("node_2", "node_3")
-builder.add_conditional_edges(
-    source="node_3",
-    path=conditional_edge_3,
-)
-builder.add_edge("node_4", "node_5")
-builder.add_edge("node_5", END)
-graph = builder.compile()
-# output = graph.invoke(
-#     {"user_input":"My"},
-#     stream_mode="values",
-# )
-for chunk in graph.stream(
-    {"user_input": "My"},
    stream_mode="debug",
-):
-    print()
-    print(chunk)
diff --git a/backend/danswer/utils/timing.py b/backend/danswer/utils/timing.py
index 0d4eb7a14d4..4aa2a8e5483 100644
--- a/backend/danswer/utils/timing.py
+++ b/backend/danswer/utils/timing.py
@@ -33,12 +33,12 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             elapsed_time_str = f"{elapsed_time:.3f}"
             log_name = func_name or func.__name__
             args_str = f" args={args} kwargs={kwargs}" if include_args else ""
-            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
-            if debug_only:
-                logger.debug(final_log)
-            else:
-                # These are generally more important logs so the level is a bit higher
-                logger.notice(final_log)
+            # final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+            # if debug_only:
+            #     logger.debug(final_log)
+            # else:
+            #     # These are generally more important logs so the level is a bit higher
+            #     logger.notice(final_log)
 
             if not print_only:
                 optional_telemetry(

From 0c13c9a106e4af6cb0a1038aa2517b9717bb5579 Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Sun, 15 Dec 2024 15:12:55 -0800
Subject: [PATCH 10/10] updates

---
 .../answer_query/nodes/answer_generation.py   |  2 +-
 .../answer_query/nodes/format_answer.py       |  2 +-
 .../agent_search/answer_query/states.py       |  2 +-
 .../agent_search/expanded_retrieval/edges.py  |  6 -----
 .../expanded_retrieval/graph_builder.py       | 11 +++-------
 .../agent_search/expanded_retrieval/states.py |  2 +-
 .../agent_search/main/graph_builder.py        | 19 ++++++++--------
 .../main/nodes/generate_initial_answer.py     |  8 ++++---
 .../shared_graph_utils/prompts.py             | 22 ++++++++++++++-----
 backend/requirements/default.txt              |  4 ++--
 10 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/backend/danswer/agent_search/answer_query/nodes/answer_generation.py b/backend/danswer/agent_search/answer_query/nodes/answer_generation.py
index 268558e08ff..c33d0619638 100644
--- a/backend/danswer/agent_search/answer_query/nodes/answer_generation.py
+++ b/backend/danswer/agent_search/answer_query/nodes/answer_generation.py
@@ -9,7 +9,7 @@ def answer_generation(state: AnswerQueryState) ->
QAGenerationOutput: query = state["query_to_answer"] - docs = state["documents"] + docs = state["reordered_documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/danswer/agent_search/answer_query/nodes/format_answer.py b/backend/danswer/agent_search/answer_query/nodes/format_answer.py index a3a06ac7347..117d157d08d 100644 --- a/backend/danswer/agent_search/answer_query/nodes/format_answer.py +++ b/backend/danswer/agent_search/answer_query/nodes/format_answer.py @@ -10,7 +10,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: query=state["query_to_answer"], quality=state["answer_quality"], answer=state["answer"], - documents=state["documents"], + documents=state["reordered_documents"], ) ], ) diff --git a/backend/danswer/agent_search/answer_query/states.py b/backend/danswer/agent_search/answer_query/states.py index bb9ac37d40e..a7973cf59e0 100644 --- a/backend/danswer/agent_search/answer_query/states.py +++ b/backend/danswer/agent_search/answer_query/states.py @@ -24,7 +24,7 @@ class QAGenerationOutput(TypedDict, total=False): class ExpandedRetrievalOutput(TypedDict): - documents: Annotated[list[InferenceSection], dedup_inference_sections] + reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections] class AnswerQueryState( diff --git a/backend/danswer/agent_search/expanded_retrieval/edges.py b/backend/danswer/agent_search/expanded_retrieval/edges.py index 80c7dabb522..b4b8a88ffb7 100644 --- a/backend/danswer/agent_search/expanded_retrieval/edges.py +++ b/backend/danswer/agent_search/expanded_retrieval/edges.py @@ -6,7 +6,6 @@ from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI from danswer.llm.interfaces import LLM @@ -43,8 +42,3 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab ) for query in rewritten_queries ] - - -def conditionally_rerank_edge(state: ExpandedRetrievalState) -> bool: - print(f"conditionally_rerank_edge state: {state.keys()}") - return bool(state["search_request"].rerank_settings) diff --git a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py index 68fec83995f..4c8f421eb79 100644 --- a/backend/danswer/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/danswer/agent_search/expanded_retrieval/graph_builder.py @@ -2,7 +2,6 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from danswer.agent_search.expanded_retrieval.edges import conditionally_rerank_edge from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval @@ -54,13 +53,9 @@ def expanded_retrieval_graph_builder() -> StateGraph: start_key="doc_retrieval", end_key="verification_kickoff", ) - graph.add_conditional_edges( - source="doc_verification", - path=conditionally_rerank_edge, - path_map={ - True: "doc_reranking", - False: END, - }, + graph.add_edge( + start_key="doc_verification", + end_key="doc_reranking", ) graph.add_edge( start_key="doc_reranking", diff --git a/backend/danswer/agent_search/expanded_retrieval/states.py 
b/backend/danswer/agent_search/expanded_retrieval/states.py
index 04811c32652..73b0c8713eb 100644
--- a/backend/danswer/agent_search/expanded_retrieval/states.py
+++ b/backend/danswer/agent_search/expanded_retrieval/states.py
@@ -33,4 +33,4 @@ class ExpandedRetrievalInput(PrimaryState, total=True):
 
 class ExpandedRetrievalOutput(TypedDict):
-    documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
diff --git a/backend/danswer/agent_search/main/graph_builder.py b/backend/danswer/agent_search/main/graph_builder.py
index 6c783e619d2..6a282639e97 100644
--- a/backend/danswer/agent_search/main/graph_builder.py
+++ b/backend/danswer/agent_search/main/graph_builder.py
@@ -47,10 +47,6 @@ def main_graph_builder() -> StateGraph:
         start_key=START,
         end_key="expanded_retrieval",
     )
-    graph.add_edge(
-        start_key="expanded_retrieval",
-        end_key="generate_initial_answer",
-    )
 
     graph.add_edge(
         start_key=START,
@@ -62,7 +58,7 @@ def main_graph_builder() -> StateGraph:
         path_map=["answer_query"],
     )
     graph.add_edge(
-        start_key="answer_query",
+        start_key=["answer_query", "expanded_retrieval"],
         end_key="generate_initial_answer",
     )
     graph.add_edge(
@@ -82,7 +78,7 @@ def main_graph_builder() -> StateGraph:
     compiled_graph = graph.compile()
     primary_llm, fast_llm = get_default_llms()
     search_request = SearchRequest(
-        query="Who made Excel and what other products did they make?",
+        query="If I am familiar with the function that I need, how can I type it into a cell?",
     )
     with get_session_context_manager() as db_session:
         inputs = MainInput(
@@ -91,9 +87,12 @@ def main_graph_builder() -> StateGraph:
             fast_llm=fast_llm,
             db_session=db_session,
         )
-        output = compiled_graph.invoke(
+        for thing in compiled_graph.stream(
             input=inputs,
+            # stream_mode="debug",
             # debug=True,
-            # subgraphs=True,
-        )
-        print(output)
+            subgraphs=True,
+        ):
+            print(thing)
+            print()
+            print()
diff --git a/backend/danswer/agent_search/main/nodes/generate_initial_answer.py b/backend/danswer/agent_search/main/nodes/generate_initial_answer.py
index 96e48bc1bda..92d70c062db 100644
--- a/backend/danswer/agent_search/main/nodes/generate_initial_answer.py
+++ b/backend/danswer/agent_search/main/nodes/generate_initial_answer.py
@@ -2,7 +2,7 @@
 
 from danswer.agent_search.main.states import InitialAnswerOutput
 from danswer.agent_search.main.states import MainState
-from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
+from danswer.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
 from danswer.agent_search.shared_graph_utils.utils import format_docs
 
 
@@ -21,7 +21,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
     """
     for decomp_answer_result in decomp_answer_results:
         if (
-            decomp_answer_result.quality == "yes"
+            decomp_answer_result.quality.lower() == "yes"
             and len(decomp_answer_result.answer) > 0
             and decomp_answer_result.answer != "I don't know"
         ):
@@ -47,5 +47,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
     # Grader
     model = state["fast_llm"]
     response = model.invoke(msg)
+    answer = response.pretty_repr()
 
-    return InitialAnswerOutput(initial_answer=response.pretty_repr())
+    print(answer)
+    return InitialAnswerOutput(initial_answer=answer)
diff --git a/backend/danswer/agent_search/shared_graph_utils/prompts.py b/backend/danswer/agent_search/shared_graph_utils/prompts.py
index f5b40a0ecb0..a3eeba29fb9 100644
--- a/backend/danswer/agent_search/shared_graph_utils/prompts.py
+++
b/backend/danswer/agent_search/shared_graph_utils/prompts.py
@@ -1,12 +1,22 @@
-REWRITE_PROMPT_MULTI = """ \n
-    Please convert an initial user question into a 2-3 more appropriate search queries for retrievel from a
-    document store. \n
+REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
+    Please convert an initial user question into 2-3 more appropriate, short and pointed search queries for retrieval from a
+    document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
+    enabling the system to search more broadly.
+    Also, try to make the search queries not redundant, i.e. not too similar! \n\n
     Here is the initial question:
     \n ------- \n
     {question}
     \n ------- \n
+    Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the query text): """
-
-    Formulate the query: """
+REWRITE_PROMPT_MULTI = """ \n
+    Please create a list of 2-3 sample documents that could answer an original question. Each document
+    should be about as long as the original question. \n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
 
 BASE_RAG_PROMPT = """ \n
     You are an assistant for question-answering tasks. Use the context provided below - and only the
@@ -40,7 +50,7 @@
     Please answer with yes or no:"""
 
 VERIFIER_PROMPT = """ \n
-    Please check whether the document seems to be relevant for the answer of the original question. Please
+    Please check whether the document seems to be relevant for answering the question. Please
     only answer with 'yes' or 'no' \n
     Here is the initial question:
     \n ------- \n
@@ -330,7 +340,7 @@
 
 INITIAL_DECOMPOSITION_PROMPT = """ \n
-    Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
+    Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
     answer the original question. The purpose for this decomposition is to isolate individual entities
     (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
     for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 8a61115147b..01a99c975fd 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -27,13 +27,13 @@ jira==3.5.1
 jsonref==1.1.0
 trafilatura==1.12.2
 langchain==0.3.7
-langchain-core==0.3.20
+langchain-core==0.3.24
 langchain-openai==0.2.9
 langchain-text-splitters==0.3.2
 langchainhub==0.1.21
 langgraph==0.2.59
 langgraph-checkpoint==2.0.5
-langgraph-sdk==0.1.36
+langgraph-sdk==0.1.44
 litellm==1.53.1
 lxml==5.3.0
 lxml_html_clean==0.2.2
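
Note on the pattern used throughout this series: every conditional edge fans out by returning a list of Send() objects, one per rewritten query, sub-question, or retrieved document, and LangGraph then runs the target node once per Send in parallel. Below is a minimal, self-contained sketch of that fan-out against the pinned langgraph==0.2.59; the state and node names (State, fan_out, retrieve) are illustrative stand-ins, not identifiers from these patches.

import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    queries: list[str]
    # operator.add merges the partial results returned by parallel branches
    results: Annotated[list[str], operator.add]


def fan_out(state: State) -> list[Send]:
    # One Send per query; each branch sees only its own payload dict
    return [Send("retrieve", {"query": q}) for q in state["queries"]]


def retrieve(state: dict) -> dict:
    # Stand-in for a real retrieval call
    return {"results": [f"docs for {state['query']}"]}


builder = StateGraph(State)
builder.add_node("retrieve", retrieve)
builder.add_conditional_edges(START, fan_out, ["retrieve"])
builder.add_edge("retrieve", END)

# Prints both branches' results merged back into a single state
print(builder.compile().invoke({"queries": ["q1", "q2"], "results": []}))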
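
The json-first parsing that this series adds to clean_and_parse_list_string can also be exercised in isolation. The sketch below mirrors that approach (strip any label before the first '[', strip code fences, normalize whitespace, try strict json.loads, fall back to ast.literal_eval for single-quoted Python-style output); the name parse_llm_list and the sample input are hypothetical.

import ast
import json
import re


def parse_llm_list(raw: str) -> list[dict]:
    # Drop labels such as "Answer:" before the first '[', then code fences
    raw = re.sub(r"^.*?(?=\[)", "", raw, flags=re.DOTALL)
    raw = re.sub(r"```json\n|\n```", "", raw)
    raw = " ".join(raw.replace("\\n", " ").replace("\n", " ").split())
    try:
        # Strict JSON first (double-quoted keys and strings)
        return json.loads(raw)
    except json.JSONDecodeError:
        # LLMs often emit single-quoted, Python-literal style lists instead
        return ast.literal_eval(raw)


print(parse_llm_list("Answer: ```json\n[{'sub_question': 'q1', 'search_term': 't1'}]\n```"))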
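
Both the verifier and the answer-quality check lean on the BinaryDecision schema (a single Literal["yes", "no"] field) to keep free-form model output out of the control flow. A small demonstration of that guard; only BinaryDecision itself comes from these patches, the loop is a demo.

from typing import Literal

from pydantic import BaseModel, ValidationError


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


for raw in ({"decision": "yes"}, {"decision": "maybe"}):
    try:
        # model_validate raises on anything other than "yes" or "no"
        print(BinaryDecision.model_validate(raw).decision)
    except ValidationError:
        print(f"rejected: {raw}")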