[ContextAugmentedFunctionAgent] multi tool function call with retrieval #148

Closed
wants to merge 45 commits
Changes from 12 commits
Commits
7e700d4
Fix conflicts
w5688414 Dec 12, 2023
62cab2c
multi tool function call with retrieval
w5688414 Dec 12, 2023
9a63a99
Update to _async_run_llm_without_hooks
w5688414 Dec 13, 2023
039f53b
Add direct prompts
w5688414 Dec 13, 2023
bc07951
Update prompts
w5688414 Dec 13, 2023
4b14f85
Add context augmented retrieval agent
w5688414 Dec 15, 2023
a3ca306
Remove type error
w5688414 Dec 15, 2023
1f4674f
Add direct prompt
w5688414 Dec 15, 2023
775c363
Add knowledge base tools impl
w5688414 Dec 15, 2023
f028c84
Add retriever config
w5688414 Dec 15, 2023
b3b1a99
Update tool retrieval
w5688414 Dec 18, 2023
67066da
Update tool retrieval
w5688414 Dec 18, 2023
51c0137
Add system prompt
w5688414 Dec 19, 2023
3320e7b
Add faiss indexing
w5688414 Dec 19, 2023
76604f8
Update erniebot-agent/erniebot_agent/agents/functional_agent_with_ret…
w5688414 Dec 19, 2023
e28cf0b
Add planning and execute rules
w5688414 Dec 19, 2023
a6d331a
Update planning logic
w5688414 Dec 20, 2023
a78fa66
Add automatic prompt engineer
w5688414 Dec 20, 2023
c25e6aa
Update prompt engineer
w5688414 Dec 21, 2023
ff980ba
Add prompt agent
w5688414 Dec 21, 2023
cc37313
Fix conflicts
w5688414 Jan 16, 2024
f6ba05b
Resolve conflicts
w5688414 Jan 16, 2024
e37f2bf
Fix conflicts
w5688414 Jan 16, 2024
6b95d87
Update retrieval tools
w5688414 Jan 16, 2024
254df32
Update name
w5688414 Jan 16, 2024
1b059d6
Update format
w5688414 Jan 16, 2024
add83cd
Update unitests
w5688414 Jan 17, 2024
b915176
remove functional_agent_with_retrieval_example.py
w5688414 Jan 17, 2024
d4a034d
Update function_agent_with_retrieval.py
w5688414 Jan 17, 2024
30d8081
Update ContextAugmentedFunctionalAgent
w5688414 Jan 17, 2024
d97ee9d
Update prompt agent
w5688414 Jan 17, 2024
2c39536
Update prompt agent
w5688414 Jan 17, 2024
4534b1d
remove prompt agent
w5688414 Jan 18, 2024
ce74d54
delete faiss_util.py
w5688414 Jan 18, 2024
f8b0450
Update
w5688414 Jan 18, 2024
a59f84e
Merge branch 'develop' of https://github.com/PaddlePaddle/ERNIE-Bot-S…
w5688414 Jan 18, 2024
c104795
delete knowledge_tools_example
w5688414 Jan 18, 2024
c85c7c4
Update format
w5688414 Jan 18, 2024
6740056
remove FunctionAgentWithQueryPlanning
w5688414 Jan 19, 2024
ccea25b
Update retrieval agents
w5688414 Jan 19, 2024
4551c7c
Add unitests
w5688414 Jan 19, 2024
f17e225
Update steps_taken
w5688414 Jan 19, 2024
b24e685
reformat
w5688414 Jan 19, 2024
a860def
Update erniebot
w5688414 Jan 19, 2024
e3dd45f
Update unitest
w5688414 Jan 19, 2024
2 changes: 2 additions & 0 deletions erniebot-agent/erniebot_agent/agents/__init__.py
@@ -15,6 +15,8 @@
from .base import Agent
from .functional_agent import FunctionalAgent
from .functional_agent_with_retrieval import (
ContextAugmentedFunctionalAgent,
FunctionalAgentWithQueryPlanning,
FunctionalAgentWithRetrieval,
FunctionalAgentWithRetrievalScoreTool,
FunctionalAgentWithRetrievalTool,
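For orientation, the two newly exported classes can now be imported straight from erniebot_agent.agents. A minimal, hedged sketch follows; the forwarded keyword arguments (llm, tools, memory) are assumptions about FunctionalAgent's constructor, not something this diff defines — only the two class names and the top_k/threshold/knowledge_base parameters come from the diffs below.

# Hypothetical import sketch; only the class names come from this PR.
from erniebot_agent.agents import (
    ContextAugmentedFunctionalAgent,   # condenses retrieved passages before the tool loop
    FunctionalAgentWithQueryPlanning,  # treats retrieval as an ordinary tool
)

# Per the diff below, ContextAugmentedFunctionalAgent takes a BaizhongSearch
# knowledge_base plus top_k/threshold; FunctionalAgentWithQueryPlanning takes
# top_k/threshold; both forward remaining kwargs (assumed: llm, tools, memory)
# to FunctionalAgent.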
148 changes: 134 additions & 14 deletions erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py
@@ -146,6 +146,20 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A

chat_history.append(HumanMessage(content=prompt))

step_input = HumanMessage(
content=self.rag_prompt.format(query=prompt, documents=results["documents"])
)
fake_chat_history: List[Message] = []
fake_chat_history.append(step_input)
llm_resp = await self._async_run_llm_without_hooks(
messages=fake_chat_history,
functions=None,
system=self.system_message.content if self.system_message is not None else None,
)

# Get RAG results
output_message = llm_resp.message

outputs = []
for item in results["documents"]:
outputs.append(
@@ -172,7 +186,8 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A
actions_taken.append(action)
# return response
tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False)
next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json)
# next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json)
next_step_input = FunctionMessage(name=action.tool_name, content=output_message.content)
tool_resp = ToolResponse(json=tool_ret_json, files=[])
await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp)

@@ -240,7 +255,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A
await self._callback_manager.on_tool_start(
agent=self, tool=self.search_tool, input_args=tool_args
)
chat_history.append(HumanMessage(content=prompt))

outputs = []
for item in results["documents"]:
outputs.append(
Expand All @@ -251,25 +266,97 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A
}
)

chat_history.append(
AIMessage(
content="",
function_call={
"name": "KnowledgeBaseTool",
"thoughts": "这是一个检索的需求,我需要在KnowledgeBaseTool知识库中检索出与输入的query相关的段落,并返回给用户。",
"arguments": tool_args,
},
# return response
tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False)
# Direct Prompt
next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作,并且检索只允许调用一次")
Collaborator:

FunctionalAgentWithRetrievalScoreTool and FunctionalAgentWithRetrievalTool can no longer be told apart at all.

Collaborator (author):

FunctionAgentWithRetrievalTool simulates a single function call.
FunctionAgentWithRetrievalScoreTool uses the prompt to guide the function call into invoking retrieval.

tool_resp = ToolResponse(json=tool_ret_json, files=[])
await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp)

num_steps_taken = 0
while num_steps_taken < self.max_steps:
curr_step_output = await self._async_step(
next_step_input, chat_history, actions_taken, files_involved
)
if curr_step_output is None:
response = self._create_finished_response(chat_history, actions_taken, files_involved)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
num_steps_taken += 1
response = self._create_stopped_response(chat_history, actions_taken, files_involved)
return response
else:
logger.info(
f"Irrelevant retrieval results. Fallbacking to FunctionalAgent for the query: {prompt}"
)
return await super()._async_run(prompt)

async def _maybe_retrieval(
self,
step_input,
):
documents = self.knowledge_base.search(step_input, top_k=self.top_k, filters=None)
documents = [item for item in documents if item["score"] > self.threshold]
results = {}
results["documents"] = documents
return results


class ContextAugmentedFunctionalAgent(FunctionalAgent):
def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs):
super().__init__(**kwargs)
self.knowledge_base = knowledge_base
self.top_k = top_k
self.threshold = threshold
self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"])
self.search_tool = KnowledgeBaseTool()

async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
results = await self._maybe_retrieval(prompt)
if len(results["documents"]) > 0:
# RAG
chat_history: List[Message] = []
actions_taken: List[AgentAction] = []
files_involved: List[AgentFile] = []

tool_args = json.dumps({"query": prompt}, ensure_ascii=False)
await self._callback_manager.on_tool_start(
agent=self, tool=self.search_tool, input_args=tool_args
)
step_input = HumanMessage(
content=self.rag_prompt.format(query=prompt, documents=results["documents"])
)
fake_chat_history: List[Message] = []
fake_chat_history.append(step_input)
llm_resp = await self._async_run_llm_without_hooks(
messages=fake_chat_history,
functions=None,
system=self.system_message.content if self.system_message is not None else None,
)

# Get RAG results
output_message = llm_resp.message

outputs = []
for item in results["documents"]:
outputs.append(
{
"id": item["id"],
"title": item["title"],
"document": item["content_se"],
}
)

# Knowledge Retrieval Tool
action = AgentAction(tool_name="KnowledgeBaseTool", tool_args=tool_args)
actions_taken.append(action)
# return response
tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False)
next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json)
# 会有无限循环调用工具的问题
next_step_input = HumanMessage(
content=f"背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}"
)
tool_resp = ToolResponse(json=tool_ret_json, files=[])
await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp)

num_steps_taken = 0
while num_steps_taken < self.max_steps:
curr_step_output = await self._async_step(
@@ -298,3 +385,36 @@ async def _maybe_retrieval(
results = {}
results["documents"] = documents
return results


class FunctionalAgentWithQueryPlanning(FunctionalAgent):
Collaborator:

What is this?

Collaborator (author):

This class treats retrieval as a tool; the planning part uses FunctionalAgent's default behavior, and some post-processing may be added later.

def __init__(self, top_k: int = 3, threshold: float = 0.1, **kwargs):
super().__init__(**kwargs)
self.top_k = top_k
self.threshold = threshold
self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"])
# self.search_tool = KnowledgeBaseTool()

async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
# RAG
chat_history: List[Message] = []
actions_taken: List[AgentAction] = []
files_involved: List[AgentFile] = []
# 会有无限循环调用工具的问题
# next_step_input = HumanMessage(
# content=f"请选择合适的工具来回答:{prompt},如果需要的话,可以对把问题分解成子问题,然后每个子问题选择合适的工具回答。"
# )
next_step_input = HumanMessage(content=prompt)
num_steps_taken = 0
while num_steps_taken < self.max_steps:
curr_step_output = await self._async_step(
next_step_input, chat_history, actions_taken, files_involved
)
if curr_step_output is None:
response = self._create_finished_response(chat_history, actions_taken, files_involved)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
num_steps_taken += 1
response = self._create_stopped_response(chat_history, actions_taken, files_involved)
return response
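Taken together, the new ContextAugmentedFunctionalAgent boils down to the hand-off sketched below, while FunctionalAgentWithQueryPlanning simply starts the default loop from the raw prompt. This is a condensed, hedged restatement of the diff for readability; the helper name build_next_step_input is hypothetical and does not exist in the code.

# Condensed sketch of the context-augmentation hand-off implemented above.
from erniebot_agent.messages import HumanMessage

async def build_next_step_input(agent, prompt: str) -> HumanMessage:
    # 1. Retrieve and filter by relevance score (see _maybe_retrieval above).
    documents = agent.knowledge_base.search(prompt, top_k=agent.top_k, filters=None)
    documents = [d for d in documents if d["score"] > agent.threshold]
    if not documents:
        # No relevant context: behave like a plain FunctionalAgent.
        return HumanMessage(content=prompt)
    # 2. Condense the passages with one plain LLM call over the RAG prompt.
    rag_query = agent.rag_prompt.format(query=prompt, documents=documents)
    llm_resp = await agent._async_run_llm_without_hooks(
        messages=[HumanMessage(content=rag_query)],
        functions=None,
        system=agent.system_message.content if agent.system_message is not None else None,
    )
    # 3. Feed the condensed answer (not the raw document JSON) into the normal
    #    function-calling loop as background information.
    return HumanMessage(
        content=f"背景信息为:{llm_resp.message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}"
    )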
7 changes: 4 additions & 3 deletions erniebot-agent/erniebot_agent/tools/baizhong_tool.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, List, Optional, Type
from typing import Any, Dict, List, Optional, Type

from erniebot_agent.messages import AIMessage, HumanMessage
from erniebot_agent.tools.schema import ToolParameterView
@@ -30,9 +30,10 @@ class BaizhongSearchTool(Tool):
ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

def __init__(
self, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
self, name, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
) -> None:
super().__init__()
self.name = name
self.db = db
self.description = description
self.few_shot_examples = []
@@ -44,7 +45,7 @@ def __init__(
self.few_shot_examples = examples
self.threshold = threshold

async def __call__(self, query: str, top_k: int = 3, filters: Optional[dict[str, Any]] = None):
async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
documents = self.db.search(query, top_k, filters)
documents = [item for item in documents if item["score"] > self.threshold]
return {"documents": documents}
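The practical effect of this change is that a BaizhongSearchTool must now be given an explicit name, which is what lets several knowledge-base tools coexist in one agent. A hedged construction sketch follows; the BaizhongSearch index `db` is assumed to be built elsewhere, and only the signature comes from this diff.

# Hypothetical usage; `db` must be a BaizhongSearch instance created elsewhere.
from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool

law_tool = BaizhongSearchTool(
    name="city_admin_law_search",  # new required argument in this diff
    description="在城市管理执法办法知识库中检索相关段落",
    db=db,
    threshold=0.1,
)
# Inside an async agent step: results = await law_tool(query="...", top_k=3)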
76 changes: 76 additions & 0 deletions erniebot-agent/erniebot_agent/tools/openai_search_tool.py
@@ -0,0 +1,76 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Type

from erniebot_agent.messages import AIMessage, HumanMessage
from erniebot_agent.tools.schema import ToolParameterView
from pydantic import Field

from .base import Tool


class OpenAISearchToolInputView(ToolParameterView):
query: str = Field(description="查询语句")
top_k: int = Field(description="返回结果数量")


class SearchResponseDocument(ToolParameterView):
title: str = Field(description="检索结果的标题")
document: str = Field(description="检索结果的内容")


class OpenAISearchToolOutputView(ToolParameterView):
documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落")


class OpenAISearchTool(Tool):
Collaborator:

What does this class have to do with OpenAI?

Collaborator (author):

The db here is langchain's FAISS, and FAISS's search interface and return values differ from 欧若拉's, so a separate tool was written.

description: str = "在知识库中检索与用户输入query相关的段落"
input_type: Type[ToolParameterView] = OpenAISearchToolInputView
ouptut_type: Type[ToolParameterView] = OpenAISearchToolOutputView

def __init__(
self, name, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
) -> None:
super().__init__()
self.name = name
self.db = db
self.description = description
self.few_shot_examples = []
if input_type is not None:
self.input_type = input_type
if output_type is not None:
self.ouptut_type = output_type
if examples is not None:
self.few_shot_examples = examples
self.threshold = threshold

async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
documents = self.db.similarity_search_with_relevance_scores(query, top_k)
docs = []
for doc, score in documents:
if score > self.threshold:
docs.append(
{"document": doc.page_content, "title": doc.metadata["source"], "meta": doc.metadata}
)

return {"documents": docs}

@property
def examples(
self,
) -> List[Any]:
few_shot_objects: List[Any] = []
for item in self.few_shot_examples:
few_shot_objects.append(HumanMessage(item["user"]))
few_shot_objects.append(
AIMessage(
"",
function_call={
"name": self.tool_name,
"thoughts": item["thoughts"],
"arguments": item["arguments"],
},
)
)

return few_shot_objects
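A hedged end-to-end sketch of the new tool over a langchain FAISS store, following the reviewer discussion above. The OpenAIEmbeddings/FAISS.from_texts construction reflects common langchain usage at the time and is an assumption; only the tool's constructor and __call__ come from this file.

# Hypothetical setup; assumes `pip install langchain faiss-cpu` and an OpenAI API key.
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from erniebot_agent.tools.openai_search_tool import OpenAISearchTool

db = FAISS.from_texts(
    ["城市绿化条例第一条……", "城市绿化条例第二条……"],
    OpenAIEmbeddings(),
    # `source` is read back as the result title in __call__ above.
    metadatas=[{"source": "绿化条例"}, {"source": "绿化条例"}],
)

green_tool = OpenAISearchTool(
    name="greening_regulation_search",
    description="在城市绿化知识库中检索与query相关的段落",
    db=db,
    threshold=0.1,
)
# Inside an async agent step: results = await green_tool(query="绿化覆盖率要求", top_k=3)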