Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Initial Implementation of the Agent Search Graph #3299

Open
wants to merge 12 commits into
base: agent-search-a
Choose a base branch
from
100 changes: 100 additions & 0 deletions backend/danswer/agent_search/answer_query/graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.answer_query.nodes.answer_check import answer_check
from danswer.agent_search.answer_query.nodes.answer_generation import answer_generation
from danswer.agent_search.answer_query.nodes.format_answer import format_answer
from danswer.agent_search.answer_query.states import AnswerQueryInput
from danswer.agent_search.answer_query.states import AnswerQueryOutput
from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)


def answer_query_graph_builder() -> StateGraph:
    """Assemble the (uncompiled) graph that answers a single query.

    The graph is a straight pipeline:
    START -> expanded retrieval -> answer generation -> answer check
    -> format answer -> END.

    Returns:
        The StateGraph with all nodes and edges registered; callers are
        expected to call ``.compile()`` on it themselves.
    """
    builder = StateGraph(
        state_schema=AnswerQueryState,
        input=AnswerQueryInput,
        output=AnswerQueryOutput,
    )

    retrieval_node = "expanded_retrieval_for_initial_decomp"

    ### Add nodes ###

    # The expanded-retrieval graph is compiled so it can run as one node here.
    node_actions = (
        (retrieval_node, expanded_retrieval_graph_builder().compile()),
        ("answer_check", answer_check),
        ("answer_generation", answer_generation),
        ("format_answer", format_answer),
    )
    for node_name, node_action in node_actions:
        builder.add_node(node=node_name, action=node_action)

    ### Add edges ###

    # Each consecutive pair in the pipeline becomes one edge.
    pipeline = (
        START,
        retrieval_node,
        "answer_generation",
        "answer_check",
        "format_answer",
        END,
    )
    for source, target in zip(pipeline, pipeline[1:]):
        builder.add_edge(start_key=source, end_key=target)

    return builder


if __name__ == "__main__":
    # Manual smoke test: build the graph, run it once against the default
    # LLMs and a real DB session, and print the resulting output state.
    from danswer.db.engine import get_session_context_manager
    from danswer.llm.factory import get_default_llms
    from danswer.context.search.models import SearchRequest

    compiled_graph = answer_query_graph_builder().compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="Who made Excel and what other products did they make?",
    )
    with get_session_context_manager() as db_session:
        output = compiled_graph.invoke(
            input=AnswerQueryInput(
                search_request=search_request,
                primary_llm=primary_llm,
                fast_llm=fast_llm,
                db_session=db_session,
                query_to_answer="Who made Excel?",
            ),
            # debug=True,
            # subgraphs=True,
        )
        print(output)
        # Alternative for inspecting intermediate steps: iterate over
        # compiled_graph.stream(input=..., subgraphs=True), which yields
        # (namespace, chunk) pairs, and print each of them.
30 changes: 30 additions & 0 deletions backend/danswer/agent_search/answer_query/nodes/answer_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.answer_query.states import QACheckOutput
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT


def answer_check(state: AnswerQueryState) -> QACheckOutput:
    """Ask the fast LLM to assess the answer produced for this query.

    Args:
        state: Current graph state; reads "query_to_answer", "answer" and
            "fast_llm".

    Returns:
        A state update whose "answer_quality" is the raw text of the LLM's
        reply to BASE_CHECK_PROMPT.
    """
    msg = [
        HumanMessage(
            content=BASE_CHECK_PROMPT.format(
                # Grade the answer against the question it was generated for.
                # answer_generation answers state["query_to_answer"] and
                # format_answer reports the quality under that same query, so
                # checking against search_request.query (the broader original
                # request) would judge the answer against the wrong question.
                question=state["query_to_answer"],
                base_answer=state["answer"],
            )
        )
    ]

    fast_llm = state["fast_llm"]
    # Drain the stream completely, then stitch the chunks back together.
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    response_str = merge_message_runs(response, chunk_separator="")[0].content

    return QACheckOutput(
        answer_quality=response_str,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.answer_query.states import QAGenerationOutput
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs


def answer_generation(state: AnswerQueryState) -> QAGenerationOutput:
    """Generate an answer to the query from the retrieved documents.

    Args:
        state: Current graph state; reads "query_to_answer",
            "reordered_documents" and "fast_llm".

    Returns:
        A state update whose "answer" is the text of the fast LLM's reply
        to BASE_RAG_PROMPT.
    """
    question = state["query_to_answer"]
    retrieved_docs = state["reordered_documents"]

    print(f"Number of verified retrieval docs: {len(retrieved_docs)}")

    rag_prompt = BASE_RAG_PROMPT.format(
        question=question,
        context=format_docs(retrieved_docs),
    )
    messages = [HumanMessage(content=rag_prompt)]

    # Consume the whole stream, then merge the chunks into a single message.
    chunks = list(state["fast_llm"].stream(prompt=messages))
    merged = merge_message_runs(chunks, chunk_separator="")

    return QAGenerationOutput(answer=merged[0].content)
16 changes: 16 additions & 0 deletions backend/danswer/agent_search/answer_query/nodes/format_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from danswer.agent_search.answer_query.states import AnswerQueryOutput
from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.answer_query.states import SearchAnswerResults


def format_answer(state: AnswerQueryState) -> AnswerQueryOutput:
    """Package the query, its answer, the check result and the documents.

    Args:
        state: Current graph state; reads "query_to_answer",
            "answer_quality", "answer" and "reordered_documents".

    Returns:
        The graph output, holding a one-element "decomp_answer_results" list.
    """
    result = SearchAnswerResults(
        query=state["query_to_answer"],
        quality=state["answer_quality"],
        answer=state["answer"],
        documents=state["reordered_documents"],
    )
    return AnswerQueryOutput(decomp_answer_results=[result])
45 changes: 45 additions & 0 deletions backend/danswer/agent_search/answer_query/states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Annotated
from typing import TypedDict

from pydantic import BaseModel

from danswer.agent_search.core_state import PrimaryState
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
from danswer.context.search.models import InferenceSection


class SearchAnswerResults(BaseModel):
    """Result record for one answered query.

    format_answer builds one of these from the graph state: the query that
    was answered, the generated answer, the answer-check text and the
    documents the answer was generated from.
    """

    # The query this result answers (state["query_to_answer"]).
    query: str
    # The generated answer text (state["answer"]).
    answer: str
    # Raw answer-check LLM response (state["answer_quality"]).
    quality: str
    # Documents used for the answer (state["reordered_documents"]).
    # NOTE(review): the Annotated metadata mirrors the annotation used on the
    # graph state; whether it has any effect on a pydantic model is not
    # visible here - confirm.
    documents: Annotated[list[InferenceSection], dedup_inference_sections]


class QACheckOutput(TypedDict, total=False):
    """State update returned by the answer_check node."""

    # Raw text of the LLM response to the answer-check prompt.
    answer_quality: str


class QAGenerationOutput(TypedDict, total=False):
    """State update returned by the answer_generation node."""

    # Text of the answer generated by the LLM.
    answer: str


class ExpandedRetrievalOutput(TypedDict):
    """Document portion of the answer-query state.

    NOTE(review): presumably written by the expanded-retrieval subgraph that
    runs as the first node of the answer-query graph - confirm.
    """

    # Read by answer_generation and format_answer.
    # NOTE(review): dedup_inference_sections is attached as Annotated metadata,
    # presumably as the reducer that merges updates to this key - its
    # behavior is defined elsewhere; confirm.
    reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]


class AnswerQueryState(
    PrimaryState,
    QACheckOutput,
    QAGenerationOutput,
    ExpandedRetrievalOutput,
    total=True,
):
    """Full state schema of the answer-query graph.

    Combines the shared PrimaryState keys with the keys contributed by the
    check, generation and retrieval steps. ``total=True`` only applies to
    the keys declared in this class body; inherited keys keep the totality
    of the class that declares them.
    """

    # The specific question this graph run answers.
    query_to_answer: str


class AnswerQueryInput(PrimaryState, total=True):
    """Input schema of the answer-query graph: PrimaryState plus the query."""

    # The specific question this graph run answers.
    query_to_answer: str


class AnswerQueryOutput(TypedDict):
    """Output schema of the answer-query graph."""

    # format_answer fills this with a single SearchAnswerResults entry.
    decomp_answer_results: list[SearchAnswerResults]
15 changes: 15 additions & 0 deletions backend/danswer/agent_search/core_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import TypedDict

from sqlalchemy.orm import Session

from danswer.context.search.models import SearchRequest
from danswer.llm.interfaces import LLM


class PrimaryState(TypedDict, total=False):
    """Shared state keys carried through the agent-search graphs.

    Declared with ``total=False``, so every key is optional at the type level.
    """

    # The search request; answer_check reads its ``query`` attribute.
    search_request: SearchRequest
    primary_llm: LLM
    # LLM used by the answer generation and answer check nodes.
    fast_llm: LLM
    # a single session for the entire agent search
    # is fine if we are only reading
    db_session: Session
Empty file.
Empty file.
114 changes: 114 additions & 0 deletions backend/danswer/agent_search/deep_answer/nodes/answer_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.main.states import MainState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace


# aggregate sub questions and answers
def deep_answer_generation(state: MainState) -> dict[str, Any]:
    """
    Generate the deep answer from the retrieved docs and the dynamic context

    Args:
        state (MainState): The current state; reads "original_question",
            "deduped_retrieval_docs", "core_answer_dynamic_context" and
            "fast_llm"

    Returns:
        dict: State update holding the model's reply under "deep_answer"
    """
    print("---DEEP GENERATE---")

    original_question = state["original_question"]
    retrieval_docs = state["deduped_retrieval_docs"]
    dynamic_context = state["core_answer_dynamic_context"]

    print(f"Number of verified retrieval docs - deep: {len(retrieval_docs)}")

    # Merge the dynamic context with the formatted documents into one
    # whitespace-normalized context string.
    raw_context = COMBINED_CONTEXT.format(
        deep_answer_context=dynamic_context,
        formated_docs=format_docs(retrieval_docs),
    )
    prompt_text = MODIFIED_RAG_PROMPT.format(
        question=original_question,
        combined_context=normalize_whitespace(raw_context),
    )
    messages = [HumanMessage(content=prompt_text)]

    # Single (non-streaming) call to the fast LLM
    llm_reply = state["fast_llm"].invoke(messages)

    return {"deep_answer": llm_reply.content}


def _format_answered_sub_qas(sub_qas: list[dict[str, Any]]) -> str:
    """Render the sub-questions whose answer check is "yes" as one text block."""
    return "\n".join(
        f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
        for sub_qa in sub_qas
        if sub_qa["sub_answer_check"] == "yes"
    )


def final_stuff(state: MainState) -> dict[str, Any]:
    """
    Print a summary of the run: the message log, the base answer, the
    answered initial sub-questions and - when one was produced - the deep
    answer with its answered sub-questions. Nothing is written to the state.

    Args:
        state (MainState): The current state; reads "log_messages",
            "initial_sub_qas" and "base_answer", plus "deep_answer" and
            "sub_qas" when a deep answer is present

    Returns:
        dict: Always an empty dict (no state update)
    """
    print("---FINAL---")

    # NOTE(review): the ordering is lexicographic on the rendered text, so it
    # is only chronological if each message renders with a sortable
    # timestamp prefix - confirm.
    time_ordered_messages = sorted(x.pretty_repr() for x in state["log_messages"])

    print("Message Log:")
    print("\n".join(time_ordered_messages))

    initial_sub_qa_context = _format_answered_sub_qas(state["initial_sub_qas"])

    base_answer = state["base_answer"]

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
    print("--------------------------------")

    # A missing or empty deep answer means the deep path was not taken.
    if not state.get("deep_answer"):
        print("No Deep Answer was required")
        return {}

    deep_answer = state["deep_answer"]
    sub_qa_context = _format_answered_sub_qas(state["sub_qas"])

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Final Deep Answer:\n{deep_answer}")
    print("--------------------------------")
    print("Sub Questions and Answers:")
    print(sub_qa_context)

    return {}
Loading
Loading