From 7e700d49450c286e15cdaddfe2c761e32bd13ea4 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 12 Dec 2023 23:30:19 +0800 Subject: [PATCH 01/43] Fix conflicts --- .../agents/functional_agent_with_retrieval.py | 50 +++++++++---- ...functional_agent_with_retrieval_example.py | 70 ++++++++++++++----- 2 files changed, 89 insertions(+), 31 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 858e91d92..1ad2f3c43 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -146,6 +146,20 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A chat_history.append(HumanMessage(content=prompt)) + step_input = HumanMessage( + content=self.rag_prompt.format(query=prompt, documents=results["documents"]) + ) + fake_chat_history: List[Message] = [] + fake_chat_history.append(step_input) + llm_resp = await self._async_run_llm( + messages=fake_chat_history, + functions=None, + system=self.system_message.content if self.system_message is not None else None, + ) + + # Get RAG results + output_message = llm_resp.message + outputs = [] for item in results["documents"]: outputs.append( @@ -172,7 +186,8 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A actions_taken.append(action) # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) + # next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) + next_step_input = FunctionMessage(name=action.tool_name, content=output_message.content) tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) @@ -241,6 +256,20 @@ async def 
_async_run(self, prompt: str, files: Optional[List[File]] = None) -> A agent=self, tool=self.search_tool, input_args=tool_args ) chat_history.append(HumanMessage(content=prompt)) + + step_input = HumanMessage( + content=self.rag_prompt.format(query=prompt, documents=results["documents"]) + ) + fake_chat_history: List[Message] = [] + fake_chat_history.append(step_input) + llm_resp = await self._async_run_llm( + messages=fake_chat_history, + functions=None, + system=self.system_message.content if self.system_message is not None else None, + ) + + # Get RAG results + output_message = llm_resp.message outputs = [] for item in results["documents"]: outputs.append( @@ -251,23 +280,16 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A } ) - chat_history.append( - AIMessage( - content="", - function_call={ - "name": "KnowledgeBaseTool", - "thoughts": "这是一个检索的需求,我需要在KnowledgeBaseTool知识库中检索出与输入的query相关的段落,并返回给用户。", - "arguments": tool_args, - }, - ) - ) + chat_history.append(AIMessage(content=output_message.content, function_call=None)) # Knowledge Retrieval Tool - action = AgentAction(tool_name="KnowledgeBaseTool", tool_args=tool_args) - actions_taken.append(action) + # action = AgentAction(tool_name="BaizhongSearchTool", tool_args=tool_args) + # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) + + next_step_input = HumanMessage(content=prompt) + tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) num_steps_taken = 0 diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py b/erniebot-agent/examples/functional_agent_with_retrieval_example.py index 374fc3eef..c12f5dce9 100644 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ 
b/erniebot-agent/examples/functional_agent_with_retrieval_example.py @@ -1,5 +1,6 @@ import argparse import asyncio +from typing import Dict, List, Type from erniebot_agent.agents import ( FunctionalAgentWithRetrieval, @@ -8,12 +9,15 @@ ) from erniebot_agent.chat_models import ERNIEBot from erniebot_agent.memory import WholeMemory +from erniebot_agent.messages import AIMessage, HumanMessage, Message from erniebot_agent.retrieval import BaizhongSearch from erniebot_agent.retrieval.document import Document from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool -from erniebot_agent.tools.base import RemoteToolkit +from erniebot_agent.tools.base import RemoteToolkit, Tool +from erniebot_agent.tools.schema import ToolParameterView from langchain.document_loaders import PyPDFDirectoryLoader from langchain.text_splitter import SpacyTextSplitter +from pydantic import Field from tqdm import tqdm import erniebot @@ -38,6 +42,38 @@ args = parser.parse_args() +class NotesToolInputView(ToolParameterView): + draft: str = Field(description="草稿文本") + + +class NotesToolOutputView(ToolParameterView): + draft_results: str = Field(description="草稿文本结果") + + +class NotesTool(Tool): + description: str = "笔记本,用于记录和保存信息的笔记本工具" + input_type: Type[ToolParameterView] = NotesToolInputView + ouptut_type: Type[ToolParameterView] = NotesToolOutputView + + async def __call__(self, draft: str) -> Dict[str, str]: + # TODO: save draft to database + return {"draft_results": "保存成功"} + + @property + def examples(self) -> List[Message]: + return [ + HumanMessage("OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中"), + AIMessage( + "", + function_call={ + "name": self.tool_name, + "thoughts": f"用户想保存笔记,我可以使用{self.tool_name}工具来保存,其中`draft`字段的内容为:'搜索的草稿'。", + "arguments": '{"draft": "搜索的草稿"}', + }, + ), + ] + + def offline_ann(data_path, baizhong_db): loader = PyPDFDirectoryLoader(data_path) documents = loader.load() @@ -70,28 +106,28 @@ def offline_ann(data_path, baizhong_db): llm = ERNIEBot(model="ernie-bot", 
api_type="custom") - retrieval_tool = BaizhongSearchTool( - description="Use Baizhong Search to retrieve documents.", db=baizhong_db, threshold=0.1 - ) + retrieval_tool = BaizhongSearchTool(description="在知识库中检索相关的段落", db=baizhong_db, threshold=0.1) # agent = FunctionalAgentWithRetrievalTool( # llm=llm, knowledge_base=baizhong_db, top_k=3, tools=[NotesTool(), retrieval_tool], memory=memory # ) # queries = [ - # "请把飞桨这两个字添加到笔记本中", - # "OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中", - # "OpenAI管理层变更会带来哪些影响?", - # "量化交易", - # "今天天气怎么样?", - # "abcabc", + # "请把飞桨这两个字添加到笔记本中", + # "OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中", + # "OpenAI管理层变更会带来哪些影响?", + # "量化交易", + # "今天天气怎么样?", + # "abcabc", # ] + queries = [ - "量化交易", - "城市景观照明中有过度照明的规定是什么?", - "这几篇文档主要内容是什么?", - "今天天气怎么样?", - "abcabc", + # "量化交易", + # "城市景观照明中有过度照明的规定是什么?", + "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", + # "这几篇文档主要内容是什么?", + # "今天天气怎么样?", + # "abcabc", ] toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") for query in queries: @@ -101,7 +137,7 @@ def offline_ann(data_path, baizhong_db): llm=llm, knowledge_base=baizhong_db, top_k=3, - tools=toolkit.get_tools() + [retrieval_tool], + tools=toolkit.get_tools() + [NotesTool(), retrieval_tool], memory=memory, ) elif args.retrieval_type == "rag_tool": @@ -118,7 +154,7 @@ def offline_ann(data_path, baizhong_db): knowledge_base=baizhong_db, top_k=3, threshold=0.1, - tools=toolkit.get_tools() + [retrieval_tool], + tools=[NotesTool(), retrieval_tool], memory=memory, ) try: From 62cab2ca5c23f6e0ce4ecd3b2b2695f9afa00e98 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 12 Dec 2023 23:38:04 +0800 Subject: [PATCH 02/43] multi tool function call with retrieval --- .../agents/functional_agent_with_retrieval.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 1ad2f3c43..1ce33f071 100644 --- 
a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -255,7 +255,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A await self._callback_manager.on_tool_start( agent=self, tool=self.search_tool, input_args=tool_args ) - chat_history.append(HumanMessage(content=prompt)) + # chat_history.append(HumanMessage(content=prompt)) step_input = HumanMessage( content=self.rag_prompt.format(query=prompt, documents=results["documents"]) @@ -280,7 +280,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A } ) - chat_history.append(AIMessage(content=output_message.content, function_call=None)) + # chat_history.append(AIMessage(content=output_message.content, function_call=None)) # Knowledge Retrieval Tool # action = AgentAction(tool_name="BaizhongSearchTool", tool_args=tool_args) @@ -288,7 +288,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - next_step_input = HumanMessage(content=prompt) + next_step_input = HumanMessage(content=f"背景:{output_message.content}, 问题:{prompt}") tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) From 9a63a99057eae93eb894e3aac2375d738c5bc2e3 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 13 Dec 2023 08:31:36 +0800 Subject: [PATCH 03/43] Update to _async_run_llm_without_hooks --- .../agents/functional_agent_with_retrieval.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 1ce33f071..dd4b1f844 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ 
b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -151,7 +151,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A ) fake_chat_history: List[Message] = [] fake_chat_history.append(step_input) - llm_resp = await self._async_run_llm( + llm_resp = await self._async_run_llm_without_hooks( messages=fake_chat_history, functions=None, system=self.system_message.content if self.system_message is not None else None, @@ -262,7 +262,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A ) fake_chat_history: List[Message] = [] fake_chat_history.append(step_input) - llm_resp = await self._async_run_llm( + llm_resp = await self._async_run_llm_without_hooks( messages=fake_chat_history, functions=None, system=self.system_message.content if self.system_message is not None else None, @@ -283,11 +283,11 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A # chat_history.append(AIMessage(content=output_message.content, function_call=None)) # Knowledge Retrieval Tool - # action = AgentAction(tool_name="BaizhongSearchTool", tool_args=tool_args) + # action = AgentAction(tool_name="KnowledgeBaseTool", tool_args=tool_args) # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - + # 这种做法会导致functional agent的retrieval tool持续触发 next_step_input = HumanMessage(content=f"背景:{output_message.content}, 问题:{prompt}") tool_resp = ToolResponse(json=tool_ret_json, files=[]) From 039f53b519d7ea231dd7306ca8fbe5096cd94fa8 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 13 Dec 2023 13:09:08 +0800 Subject: [PATCH 04/43] Add direct prompts --- .../agents/functional_agent_with_retrieval.py | 19 ++----------------- ...functional_agent_with_retrieval_example.py | 4 ++-- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py 
b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index dd4b1f844..6afd5e37e 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -255,21 +255,12 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A await self._callback_manager.on_tool_start( agent=self, tool=self.search_tool, input_args=tool_args ) - # chat_history.append(HumanMessage(content=prompt)) - step_input = HumanMessage( content=self.rag_prompt.format(query=prompt, documents=results["documents"]) ) fake_chat_history: List[Message] = [] fake_chat_history.append(step_input) - llm_resp = await self._async_run_llm_without_hooks( - messages=fake_chat_history, - functions=None, - system=self.system_message.content if self.system_message is not None else None, - ) - # Get RAG results - output_message = llm_resp.message outputs = [] for item in results["documents"]: outputs.append( @@ -280,18 +271,12 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A } ) - # chat_history.append(AIMessage(content=output_message.content, function_call=None)) - - # Knowledge Retrieval Tool - # action = AgentAction(tool_name="KnowledgeBaseTool", tool_args=tool_args) - # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - # 这种做法会导致functional agent的retrieval tool持续触发 - next_step_input = HumanMessage(content=f"背景:{output_message.content}, 问题:{prompt}") - + next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作") tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) + num_steps_taken = 0 while num_steps_taken < self.max_steps: curr_step_output = await self._async_step( diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py 
b/erniebot-agent/examples/functional_agent_with_retrieval_example.py index c12f5dce9..d3e918703 100644 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ b/erniebot-agent/examples/functional_agent_with_retrieval_example.py @@ -122,10 +122,10 @@ def offline_ann(data_path, baizhong_db): # ] queries = [ - # "量化交易", + "量化交易", # "城市景观照明中有过度照明的规定是什么?", "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", - # "这几篇文档主要内容是什么?", + "这几篇文档主要内容是什么?", # "今天天气怎么样?", # "abcabc", ] From bc079518bc7eec85d7dab48d477aed040819c163 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 13 Dec 2023 14:20:20 +0800 Subject: [PATCH 05/43] Update prompts --- .../agents/functional_agent_with_retrieval.py | 2 +- .../examples/functional_agent_with_retrieval_example.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 6afd5e37e..2e1658b37 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -273,7 +273,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作") + next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作,并且检索只允许调用一次") tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py b/erniebot-agent/examples/functional_agent_with_retrieval_example.py index d3e918703..fa23ea75d 100644 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ b/erniebot-agent/examples/functional_agent_with_retrieval_example.py @@ 
-123,11 +123,11 @@ def offline_ann(data_path, baizhong_db): queries = [ "量化交易", - # "城市景观照明中有过度照明的规定是什么?", + "城市景观照明中有过度照明的规定是什么?", "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", "这几篇文档主要内容是什么?", - # "今天天气怎么样?", - # "abcabc", + "今天天气怎么样?", + "abcabc", ] toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") for query in queries: From 4b14f85873a1bc106a0f01ad3c8c14a6e56eebbd Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 15 Dec 2023 09:53:57 +0000 Subject: [PATCH 06/43] Add context augmented retrieval agent --- .../erniebot_agent/agents/__init__.py | 1 + .../agents/functional_agent_with_retrieval.py | 83 +++++++++++++++++++ ...functional_agent_with_retrieval_example.py | 24 ++++-- 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/__init__.py b/erniebot-agent/erniebot_agent/agents/__init__.py index a824313b9..4fba64219 100644 --- a/erniebot-agent/erniebot_agent/agents/__init__.py +++ b/erniebot-agent/erniebot_agent/agents/__init__.py @@ -15,6 +15,7 @@ from .base import Agent from .functional_agent import FunctionalAgent from .functional_agent_with_retrieval import ( + ContextAugmentedFunctionalAgent, FunctionalAgentWithRetrieval, FunctionalAgentWithRetrievalScoreTool, FunctionalAgentWithRetrievalTool, diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 2e1658b37..91de2e162 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -305,3 +305,86 @@ async def _maybe_retrieval( results = {} results["documents"] = documents return results + + +class ContextAugmentedFunctionalAgent(FunctionalAgent): + def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): + super().__init__(**kwargs) + self.knowledge_base = knowledge_base + self.top_k = top_k + 
self.threshold = threshold + self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) + self.search_tool = KnowledgeBaseTool() + + async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: + results = await self._maybe_retrieval(prompt) + if len(results["documents"]) > 0: + # RAG + chat_history: List[Message] = [] + actions_taken: List[AgentAction] = [] + files_involved: List[AgentFile] = [] + + tool_args = json.dumps({"query": prompt}, ensure_ascii=False) + await self._callback_manager.on_tool_start( + agent=self, tool=self.search_tool, input_args=tool_args + ) + step_input = HumanMessage( + content=self.rag_prompt.format(query=prompt, documents=results["documents"]) + ) + fake_chat_history: List[Message] = [] + fake_chat_history.append(step_input) + llm_resp = await self._async_run_llm_without_hooks( + messages=fake_chat_history, + functions=None, + system=self.system_message.content if self.system_message is not None else None, + ) + + # Get RAG results + output_message = llm_resp.message + + outputs = [] + for item in results["documents"]: + outputs.append( + { + "id": item["id"], + "title": item["title"], + "document": item["content_se"], + } + ) + + # return response + tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) + next_step_input = HumanMessage( + content=f"背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" + ) + tool_resp = ToolResponse(json=tool_ret_json, files=[]) + await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) + + num_steps_taken = 0 + while num_steps_taken < self.max_steps: + curr_step_output = await self._async_step( + next_step_input, chat_history, actions_taken, files_involved + ) + if curr_step_output is None: + response = self._create_finished_response(chat_history, actions_taken, files_involved) + self.memory.add_message(chat_history[0]) + self.memory.add_message(chat_history[-1]) + return 
response + num_steps_taken += 1 + response = self._create_stopped_response(chat_history, actions_taken, files_involved) + return response + else: + logger.info( + f"Irrelevant retrieval results. Fallbacking to FunctionalAgent for the query: {prompt}" + ) + return await super()._async_run(prompt) + + async def _maybe_retrieval( + self, + step_input, + ): + documents = self.knowledge_base.search(step_input, top_k=self.top_k, filters=None) + documents = [item for item in documents if item["score"] > self.threshold] + results = {} + results["documents"] = documents + return results diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py b/erniebot-agent/examples/functional_agent_with_retrieval_example.py index fa23ea75d..bbdd1bd59 100644 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ b/erniebot-agent/examples/functional_agent_with_retrieval_example.py @@ -3,6 +3,7 @@ from typing import Dict, List, Type from erniebot_agent.agents import ( + ContextAugmentedFunctionalAgent, FunctionalAgentWithRetrieval, FunctionalAgentWithRetrievalScoreTool, FunctionalAgentWithRetrievalTool, @@ -35,7 +36,7 @@ parser.add_argument("--project_id", default=-1, type=int, help="The API Key.") parser.add_argument( "--retrieval_type", - choices=["rag", "rag_tool", "rag_threshold"], + choices=["rag", "rag_tool", "rag_threshold", "context_aug"], default="rag", help="Retrieval type, default to rag.", ) @@ -57,7 +58,7 @@ class NotesTool(Tool): async def __call__(self, draft: str) -> Dict[str, str]: # TODO: save draft to database - return {"draft_results": "保存成功"} + return {"draft_results": "草稿在笔记本中保存成功"} @property def examples(self) -> List[Message]: @@ -122,12 +123,12 @@ def offline_ann(data_path, baizhong_db): # ] queries = [ - "量化交易", - "城市景观照明中有过度照明的规定是什么?", + # "量化交易", + # "城市景观照明中有过度照明的规定是什么?", "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", - "这几篇文档主要内容是什么?", - "今天天气怎么样?", - "abcabc", + # "这几篇文档主要内容是什么?", + # "今天天气怎么样?", + # "abcabc", ] toolkit = 
RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") for query in queries: @@ -157,6 +158,15 @@ def offline_ann(data_path, baizhong_db): tools=[NotesTool(), retrieval_tool], memory=memory, ) + elif args.retrieval_type == "context_aug": + agent = ContextAugmentedFunctionalAgent( # type: ignore + llm=llm, + knowledge_base=baizhong_db, + top_k=3, + threshold=0.1, + tools=[NotesTool(), retrieval_tool], + memory=memory, + ) try: response = asyncio.run(agent.async_run(query)) print(f"query: {query}") From a3ca306872aa6553797f6723e0283753fdd4e685 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 15 Dec 2023 09:57:26 +0000 Subject: [PATCH 07/43] Remove type error --- .../erniebot_agent/agents/functional_agent_with_retrieval.py | 1 + .../examples/functional_agent_with_retrieval_example.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 91de2e162..b26b107da 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -354,6 +354,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) + # 会有无限循环调用工具的问题 next_step_input = HumanMessage( content=f"背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" ) diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py b/erniebot-agent/examples/functional_agent_with_retrieval_example.py index bbdd1bd59..dfcdd5c04 100644 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ b/erniebot-agent/examples/functional_agent_with_retrieval_example.py @@ -134,7 +134,7 @@ def offline_ann(data_path, baizhong_db): for query in queries: memory = WholeMemory() if args.retrieval_type == "rag": - 
agent = FunctionalAgentWithRetrieval( + agent = FunctionalAgentWithRetrieval( # type: ignore llm=llm, knowledge_base=baizhong_db, top_k=3, From 1f4674f9a35eedb21a3f8eadd6372c9f4cc33a07 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 15 Dec 2023 10:01:52 +0000 Subject: [PATCH 08/43] Add direct prompt --- .../agents/functional_agent_with_retrieval.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index b26b107da..33b44c71e 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -255,11 +255,6 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A await self._callback_manager.on_tool_start( agent=self, tool=self.search_tool, input_args=tool_args ) - step_input = HumanMessage( - content=self.rag_prompt.format(query=prompt, documents=results["documents"]) - ) - fake_chat_history: List[Message] = [] - fake_chat_history.append(step_input) outputs = [] for item in results["documents"]: @@ -273,6 +268,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) + # Direct Prompt next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作,并且检索只允许调用一次") tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) From 775c363715431e87baec14094b0094dab559850f Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 15 Dec 2023 11:43:14 +0000 Subject: [PATCH 09/43] Add knowledge base tools impl --- .../erniebot_agent/agents/__init__.py | 1 + .../agents/functional_agent_with_retrieval.py | 34 +++++++ .../erniebot_agent/tools/baizhong_tool.py | 7 +- 
.../examples/knowledge_tools_example.py | 97 +++++++++++++++++++ 4 files changed, 136 insertions(+), 3 deletions(-) create mode 100644 erniebot-agent/examples/knowledge_tools_example.py diff --git a/erniebot-agent/erniebot_agent/agents/__init__.py b/erniebot-agent/erniebot_agent/agents/__init__.py index 4fba64219..639df6020 100644 --- a/erniebot-agent/erniebot_agent/agents/__init__.py +++ b/erniebot-agent/erniebot_agent/agents/__init__.py @@ -16,6 +16,7 @@ from .functional_agent import FunctionalAgent from .functional_agent_with_retrieval import ( ContextAugmentedFunctionalAgent, + FunctionalAgentWithQueryPlanning, FunctionalAgentWithRetrieval, FunctionalAgentWithRetrievalScoreTool, FunctionalAgentWithRetrievalTool, diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 33b44c71e..9ee5a45b7 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -385,3 +385,37 @@ async def _maybe_retrieval( results = {} results["documents"] = documents return results + + +class FunctionalAgentWithQueryPlanning(FunctionalAgent): + def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): + super().__init__(**kwargs) + self.knowledge_base = knowledge_base + self.top_k = top_k + self.threshold = threshold + self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) + self.search_tool = KnowledgeBaseTool() + + async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: + # RAG + chat_history: List[Message] = [] + actions_taken: List[AgentAction] = [] + files_involved: List[AgentFile] = [] + # 会有无限循环调用工具的问题 + # next_step_input = HumanMessage( + # content=f"请选择合适的工具来回答:{prompt},如果需要的话,可以对把问题分解成子问题,然后每个子问题选择合适的工具回答。" + # ) + next_step_input = HumanMessage(content=prompt) 
+ num_steps_taken = 0 + while num_steps_taken < self.max_steps: + curr_step_output = await self._async_step( + next_step_input, chat_history, actions_taken, files_involved + ) + if curr_step_output is None: + response = self._create_finished_response(chat_history, actions_taken, files_involved) + self.memory.add_message(chat_history[0]) + self.memory.add_message(chat_history[-1]) + return response + num_steps_taken += 1 + response = self._create_stopped_response(chat_history, actions_taken, files_involved) + return response diff --git a/erniebot-agent/erniebot_agent/tools/baizhong_tool.py b/erniebot-agent/erniebot_agent/tools/baizhong_tool.py index a27094dcf..0a8ecfaec 100644 --- a/erniebot-agent/erniebot_agent/tools/baizhong_tool.py +++ b/erniebot-agent/erniebot_agent/tools/baizhong_tool.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Type +from typing import Any, Dict, List, Optional, Type from erniebot_agent.messages import AIMessage, HumanMessage from erniebot_agent.tools.schema import ToolParameterView @@ -30,9 +30,10 @@ class BaizhongSearchTool(Tool): ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView def __init__( - self, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None + self, name, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None ) -> None: super().__init__() + self.name = name self.db = db self.description = description self.few_shot_examples = [] @@ -44,7 +45,7 @@ def __init__( self.few_shot_examples = examples self.threshold = threshold - async def __call__(self, query: str, top_k: int = 3, filters: Optional[dict[str, Any]] = None): + async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None): documents = self.db.search(query, top_k, filters) documents = [item for item in documents if item["score"] > self.threshold] return {"documents": documents} diff --git 
a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py new file mode 100644 index 000000000..4c71294fc --- /dev/null +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -0,0 +1,97 @@ +import argparse +import asyncio + +from erniebot_agent.agents import FunctionalAgentWithQueryPlanning +from erniebot_agent.chat_models import ERNIEBot +from erniebot_agent.memory import WholeMemory +from erniebot_agent.retrieval import BaizhongSearch +from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool +from erniebot_agent.tools.base import RemoteToolkit + +import erniebot + +parser = argparse.ArgumentParser() +parser.add_argument("--base_url", type=str, help="The Aurora serving path.") +parser.add_argument("--data_path", default="construction_regulations", type=str, help="The data path.") +parser.add_argument( + "--access_token", default="ai_studio_access_token", type=str, help="The aistudio access token." +) +parser.add_argument("--api_type", default="qianfan", type=str, help="The aistudio access token.") +parser.add_argument("--api_key", default="", type=str, help="The API Key.") +parser.add_argument("--secret_key", default="", type=str, help="The secret key.") +parser.add_argument("--indexing", action="store_true", help="The indexing step.") +parser.add_argument("--project_id", default=-1, type=int, help="The API Key.") +parser.add_argument( + "--retrieval_type", + choices=["summary_fulltext_tools", "knowledge_tools"], + default="knowledge_tools", + help="Retrieval type, default to rag.", +) +args = parser.parse_args() + +if __name__ == "__main__": + erniebot.api_type = args.api_type + erniebot.access_token = args.access_token + baizhong_db = BaizhongSearch( + base_url=args.base_url, + project_name="construct_assistant2", + remark="construction assistant test dataset", + project_id=args.project_id if args.project_id != -1 else None, + ) + print(baizhong_db.project_id) + + llm = ERNIEBot(model="ernie-bot", 
api_type="custom") + + # 建筑规范数据集 + retrieval_tool = BaizhongSearchTool( + name="construction_search", description="提供城市管理执法办法相关的信息", db=baizhong_db, threshold=0.1 + ) + # OpenAI数据集 + openai_tool = BaizhongSearchTool( + name="openai_search", description="提供关于OpenAI公司的信息", db=baizhong_db, threshold=0.1 + ) + # 金融数据集 + finance_tool = BaizhongSearchTool( + name="financial_search", description="提供关于量化交易相关的信息", db=baizhong_db, threshold=0.1 + ) + + summary_tool = BaizhongSearchTool( + name="text_summary_search", description="使用这个工具总结与作者生活相关的问题", db=baizhong_db, threshold=0.1 + ) + vector_tool = BaizhongSearchTool( + name="fulltext_search", description="使用这个工具检索特定的上下文,以回答有关作者生活的特定问题", db=baizhong_db, threshold=0.1 + ) + queries = [ + "量化交易", + "OpenAI管理层变更会带来哪些影响?" "城市景观照明中有过度照明的规定是什么?", + "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", + "这几篇文档主要内容是什么?", + "今天天气怎么样?", + "abcabc", + ] + toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") + for query in queries: + memory = WholeMemory() + if args.retrieval_type == "summary_fulltext": + agent = FunctionalAgentWithQueryPlanning( # type: ignore + llm=llm, + knowledge_base=baizhong_db, + top_k=3, + tools=[summary_tool, vector_tool], + memory=memory, + ) + elif args.retrieval_type == "knowledge_tools": + agent = FunctionalAgentWithQueryPlanning( # type: ignore + llm=llm, + knowledge_base=baizhong_db, + top_k=3, + tools=toolkit.get_tools() + [retrieval_tool, openai_tool, finance_tool], + memory=memory, + ) + + try: + response = asyncio.run(agent.async_run(query)) + print(f"query: {query}") + print(f"agent response: {response}") + except Exception as e: + print(e) From f028c841bf249dfed9d144e50e7713f79e4ca2cd Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 15 Dec 2023 12:02:08 +0000 Subject: [PATCH 10/43] Add retriever config --- .../examples/knowledge_tools_example.py | 4 + erniebot-agent/examples/text_summarization.py | 61 +++++++++++++ erniebot-agent/examples/utils.py | 89 +++++++++++++++++++ 3 files changed, 
154 insertions(+) create mode 100644 erniebot-agent/examples/text_summarization.py create mode 100644 erniebot-agent/examples/utils.py diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index 4c71294fc..2e587732c 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -61,6 +61,9 @@ vector_tool = BaizhongSearchTool( name="fulltext_search", description="使用这个工具检索特定的上下文,以回答有关作者生活的特定问题", db=baizhong_db, threshold=0.1 ) + tool_retriever = BaizhongSearchTool( + name="tool_retriever", description="用于检索与query相关的tools列表", db=baizhong_db, threshold=0.1 + ) queries = [ "量化交易", "OpenAI管理层变更会带来哪些影响?" "城市景观照明中有过度照明的规定是什么?", @@ -81,6 +84,7 @@ memory=memory, ) elif args.retrieval_type == "knowledge_tools": + # TODO(wugaosheng) Add knowledge base tool retriever for tool selection agent = FunctionalAgentWithQueryPlanning( # type: ignore llm=llm, knowledge_base=baizhong_db, diff --git a/erniebot-agent/examples/text_summarization.py b/erniebot-agent/examples/text_summarization.py new file mode 100644 index 000000000..351741dc8 --- /dev/null +++ b/erniebot-agent/examples/text_summarization.py @@ -0,0 +1,61 @@ +import argparse +import os + +import jsonlines +from utils import create_abstract, erniebot_chat, read_data, split_text + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--api_type", default=None, type=str, help="The API Key.") +parser.add_argument("--access_token", default=None, type=str, help="The secret key.") +parser.add_argument("--data_path", default='data/json_data.jsonl', type=str, help="The data path.") +parser.add_argument("--output_path", default='data/finance_abstract', type=str, help="The output path.") +parser.add_argument('--chatbot_type', choices=['erniebot'], default="erniebot", + help="The chatbot model types") +args = parser.parse_args() +# yapf: enable + + +def summarize_text(text: str): + if not text: 
def read_data(json_path):
    """Load every record of a JSON Lines file into a list.

    Args:
        json_path: Path of the ``.jsonl`` file to read.

    Returns:
        list: The decoded records, in file order.
    """
    with jsonlines.open(json_path, "r") as reader:
        return list(reader)


def create_abstract(chunk: str) -> Dict[str, str]:
    """Build the user message that asks the chat model to summarize ``chunk``.

    Args:
        chunk (str): The chunk of text to summarize.

    Returns:
        Dict[str, str]: A ``{"role": "user", "content": ...}`` message whose
        content is the chunk followed by the (Chinese) summarization request.
    """
    return {
        "role": "user",
        "content": f"""{chunk},请用中文对上述文章进行总结,总结需要有概括性,不允许输出与文章内容无关的信息,字数控制在500字以内。""",
    }


def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]:
    """Split text into chunks of at most ``max_length`` characters.

    Paragraphs (lines separated by ``"\\n"``) are packed greedily into a
    chunk; one character per paragraph is budgeted for the joining newline.
    A single paragraph that cannot fit a chunk on its own is cut into
    ``max_length``-sized slices instead of being emitted oversize.

    Args:
        text (str): The text to split.
        max_length (int, optional): The maximum length of each chunk.
            Defaults to 8192.

    Yields:
        str: The next chunk of text.

    Raises:
        ValueError: If ``max_length`` is not a positive integer.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")

    current_chunk = []
    current_length = 0

    for paragraph in text.split("\n"):
        needed = len(paragraph) + 1  # +1 accounts for the joining newline

        if needed > max_length:
            # The paragraph can never share a chunk: flush what is pending
            # (only if there is something, so no empty chunk is produced)
            # and hard-split the paragraph itself.
            if current_chunk:
                yield "\n".join(current_chunk)
                current_chunk = []
                current_length = 0
            for start in range(0, len(paragraph), max_length):
                yield paragraph[start : start + max_length]
            continue

        if current_length + needed > max_length:
            # current_chunk is non-empty here because the paragraph fits alone.
            yield "\n".join(current_chunk)
            current_chunk = []
            current_length = 0

        current_chunk.append(paragraph)
        current_length += needed

    if current_chunk:
        yield "\n".join(current_chunk)


def erniebot_chat(
    messages, model="ernie-bot-8k", api_type="aistudio", access_token=None, functions=None, **kwargs
):
    """Send ``messages`` to ``erniebot.ChatCompletion.create`` and return its response.

    Args:
        messages: dict or list, the input message(s).
        model: str, the model name.
        api_type: str, the API backend; the original documentation lists
            'aistudio' and 'webchat' as options.
        access_token: str, the access token.
        functions: list, function descriptions; forwarded only when not None.
        kwargs: extra keyword arguments passed through to the API call.

    Returns:
        Whatever ``erniebot.ChatCompletion.create`` returns, unmodified.
    """
    _config = dict(
        api_type=api_type,
        access_token=access_token,
    )
    # Only forward ``functions`` when the caller supplied it, exactly as the
    # two separate call sites did before.
    if functions is not None:
        kwargs["functions"] = functions
    return erniebot.ChatCompletion.create(_config_=_config, model=model, messages=messages, **kwargs)
deletions(-) create mode 100644 erniebot-agent/erniebot_agent/tools/openai_search_tool.py diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 9ee5a45b7..558293bc0 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -388,13 +388,12 @@ async def _maybe_retrieval( class FunctionalAgentWithQueryPlanning(FunctionalAgent): - def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): + def __init__(self, top_k: int = 3, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) - self.knowledge_base = knowledge_base self.top_k = top_k self.threshold = threshold self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) - self.search_tool = KnowledgeBaseTool() + # self.search_tool = KnowledgeBaseTool() async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: # RAG diff --git a/erniebot-agent/erniebot_agent/tools/openai_search_tool.py b/erniebot-agent/erniebot_agent/tools/openai_search_tool.py new file mode 100644 index 000000000..240c2e153 --- /dev/null +++ b/erniebot-agent/erniebot_agent/tools/openai_search_tool.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Type + +from erniebot_agent.messages import AIMessage, HumanMessage +from erniebot_agent.tools.schema import ToolParameterView +from pydantic import Field + +from .base import Tool + + +class OpenAISearchToolInputView(ToolParameterView): + query: str = Field(description="查询语句") + top_k: int = Field(description="返回结果数量") + + +class SearchResponseDocument(ToolParameterView): + title: str = Field(description="检索结果的标题") + document: str = Field(description="检索结果的内容") + + +class OpenAISearchToolOutputView(ToolParameterView): + documents: 
List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落") + + +class OpenAISearchTool(Tool): + description: str = "在知识库中检索与用户输入query相关的段落" + input_type: Type[ToolParameterView] = OpenAISearchToolInputView + ouptut_type: Type[ToolParameterView] = OpenAISearchToolOutputView + + def __init__( + self, name, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None + ) -> None: + super().__init__() + self.name = name + self.db = db + self.description = description + self.few_shot_examples = [] + if input_type is not None: + self.input_type = input_type + if output_type is not None: + self.ouptut_type = output_type + if examples is not None: + self.few_shot_examples = examples + self.threshold = threshold + + async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None): + documents = self.db.similarity_search_with_relevance_scores(query, top_k) + docs = [] + for doc, score in documents: + if score > self.threshold: + docs.append( + {"document": doc.page_content, "title": doc.metadata["source"], "meta": doc.metadata} + ) + + return {"documents": docs} + + @property + def examples( + self, + ) -> List[Any]: + few_shot_objects: List[Any] = [] + for item in self.few_shot_examples: + few_shot_objects.append(HumanMessage(item["user"])) + few_shot_objects.append( + AIMessage( + "", + function_call={ + "name": self.tool_name, + "thoughts": item["thoughts"], + "arguments": item["arguments"], + }, + ) + ) + + return few_shot_objects diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index 2e587732c..d765d9704 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -1,12 +1,20 @@ import argparse import asyncio +from typing import Dict, List, Type from erniebot_agent.agents import FunctionalAgentWithQueryPlanning from erniebot_agent.chat_models import ERNIEBot from 
erniebot_agent.memory import WholeMemory +from erniebot_agent.messages import AIMessage, HumanMessage, Message from erniebot_agent.retrieval import BaizhongSearch from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool -from erniebot_agent.tools.base import RemoteToolkit +from erniebot_agent.tools.base import RemoteToolkit, Tool +from erniebot_agent.tools.openai_search_tool import OpenAISearchTool +from erniebot_agent.tools.schema import ToolParameterView +from langchain.docstore.document import Document +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.vectorstores import FAISS +from pydantic import Field import erniebot @@ -27,47 +35,126 @@ default="knowledge_tools", help="Retrieval type, default to rag.", ) +parser.add_argument( + "--search_engine", + choices=["baizhong", "openai"], + default="baizhong", + help="search_engine.", +) args = parser.parse_args() + +class NotesToolInputView(ToolParameterView): + draft: str = Field(description="草稿文本") + + +class NotesToolOutputView(ToolParameterView): + draft_results: str = Field(description="草稿文本结果") + + +class NotesTool(Tool): + description: str = "笔记本,用于记录和保存信息的笔记本工具" + input_type: Type[ToolParameterView] = NotesToolInputView + ouptut_type: Type[ToolParameterView] = NotesToolOutputView + + async def __call__(self, draft: str) -> Dict[str, str]: + # TODO: save draft to database + return {"draft_results": "草稿在笔记本中保存成功"} + + @property + def examples(self) -> List[Message]: + return [ + HumanMessage("OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中"), + AIMessage( + "", + function_call={ + "name": self.tool_name, + "thoughts": f"用户想保存笔记,我可以使用{self.tool_name}工具来保存,其中`draft`字段的内容为:'搜索的草稿'。", + "arguments": '{"draft": "搜索的草稿"}', + }, + ), + ] + + if __name__ == "__main__": erniebot.api_type = args.api_type erniebot.access_token = args.access_token - baizhong_db = BaizhongSearch( - base_url=args.base_url, - project_name="construct_assistant2", - remark="construction assistant test dataset", - 
project_id=args.project_id if args.project_id != -1 else None, - ) - print(baizhong_db.project_id) llm = ERNIEBot(model="ernie-bot", api_type="custom") + if args.search_engine == "baizhong": + baizhong_db = BaizhongSearch( + base_url=args.base_url, + project_name="construct_assistant2", + remark="construction assistant test dataset", + project_id=args.project_id if args.project_id != -1 else None, + ) + print(baizhong_db.project_id) + # 建筑规范数据集 + city_management = BaizhongSearchTool( + name="city_administrative_law_enforcement", + description="提供城市管理执法办法相关的信息", + db=baizhong_db, + threshold=0.1, + ) + city_design = BaizhongSearchTool( + name="city_design_management", description="提供城市设计管理办法的信息", db=baizhong_db, threshold=0.1 + ) + city_lighting = BaizhongSearchTool( + name="city_lighting", description="提供关于城市照明管理规定的信息", db=baizhong_db, threshold=0.1 + ) + + summary_tool = BaizhongSearchTool( + name="text_summary_search", description="使用这个工具总结与作者生活相关的问题", db=baizhong_db, threshold=0.1 + ) + vector_tool = BaizhongSearchTool( + name="fulltext_search", + description="使用这个工具检索特定的上下文,以回答有关作者生活的特定问题", + db=baizhong_db, + threshold=0.1, + ) + tool_retriever = BaizhongSearchTool( + name="tool_retriever", description="用于检索与query相关的tools列表", db=baizhong_db, threshold=0.1 + ) + elif args.search_engine == "openai": + embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") + faiss = FAISS.load_local("城市管理执法办法", embeddings) + openai_city_management = OpenAISearchTool( + name="city_administrative_law_enforcement", + description="提供城市管理执法办法相关的信息", + db=faiss, + threshold=0.1, + ) + faiss = FAISS.load_local("城市设计管理办法", embeddings) + openai_city_design = OpenAISearchTool( + name="city_design_management", description="提供城市设计管理办法的信息", db=faiss, threshold=0.1 + ) + faiss = FAISS.load_local("城市照明管理规定", embeddings) + openai_city_lighting = OpenAISearchTool( + name="city_lighting", description="提供关于城市照明管理规定的信息", db=faiss, threshold=0.1 + ) + # TODO(wugaoshewng) 
加入APE后,变成knowledge_base_toolkit + # faiss = FAISS.load_local("tool_retriever", embeddings) + tool_map = { + "city_administrative_law_enforcement": openai_city_management, + "city_design_management": openai_city_design, + "city_lighting": openai_city_lighting, + } + docs = [] + for tool in tool_map.values(): + doc = Document(page_content=tool.description, metadata={"tool_name": tool.name}) + docs.append(doc) + + faiss_tool = FAISS.from_documents(docs, embeddings) + tool_retriever = OpenAISearchTool( # type: ignore + name="tool_retriever", description="用于检索与query相关的tools列表", db=faiss_tool, threshold=0.1 + ) - # 建筑规范数据集 - retrieval_tool = BaizhongSearchTool( - name="construction_search", description="提供城市管理执法办法相关的信息", db=baizhong_db, threshold=0.1 - ) - # OpenAI数据集 - openai_tool = BaizhongSearchTool( - name="openai_search", description="提供关于OpenAI公司的信息", db=baizhong_db, threshold=0.1 - ) - # 金融数据集 - finance_tool = BaizhongSearchTool( - name="financial_search", description="提供关于量化交易相关的信息", db=baizhong_db, threshold=0.1 - ) - - summary_tool = BaizhongSearchTool( - name="text_summary_search", description="使用这个工具总结与作者生活相关的问题", db=baizhong_db, threshold=0.1 - ) - vector_tool = BaizhongSearchTool( - name="fulltext_search", description="使用这个工具检索特定的上下文,以回答有关作者生活的特定问题", db=baizhong_db, threshold=0.1 - ) - tool_retriever = BaizhongSearchTool( - name="tool_retriever", description="用于检索与query相关的tools列表", db=baizhong_db, threshold=0.1 - ) queries = [ "量化交易", - "OpenAI管理层变更会带来哪些影响?" 
"城市景观照明中有过度照明的规定是什么?", + "OpenAI管理层变更会带来哪些影响?", + "城市景观照明中有过度照明的规定是什么?", "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", + "请比较一下城市设计管理和照明管理规定的区别?", "这几篇文档主要内容是什么?", "今天天气怎么样?", "abcabc", @@ -78,18 +165,26 @@ if args.retrieval_type == "summary_fulltext": agent = FunctionalAgentWithQueryPlanning( # type: ignore llm=llm, - knowledge_base=baizhong_db, top_k=3, tools=[summary_tool, vector_tool], memory=memory, ) elif args.retrieval_type == "knowledge_tools": # TODO(wugaosheng) Add knowledge base tool retriever for tool selection + # tool_results: Dict = tool_retriever(query)["documents"] + tool_results = asyncio.run(tool_retriever(query))["documents"] + selected_tools = [] + for item in tool_results: + tool_name = tool_map[item["meta"]["tool_name"]] + selected_tools.append(tool_name) + + # selected_tools = [tool_map[item['meta']["tool_name"]] for item in tool_results] agent = FunctionalAgentWithQueryPlanning( # type: ignore llm=llm, - knowledge_base=baizhong_db, top_k=3, - tools=toolkit.get_tools() + [retrieval_tool, openai_tool, finance_tool], + # tools=toolkit.get_tools() + [city_management, city_design, city_lighting], + # tools = [NotesTool(),city_management, city_design, city_lighting], + tools=[NotesTool] + selected_tools, memory=memory, ) From 67066dabf477cee2c4d7820bb91a362beddba454 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 18 Dec 2023 04:35:13 +0000 Subject: [PATCH 12/43] Update tool retrieval --- erniebot-agent/examples/knowledge_tools_example.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index d765d9704..415755544 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -171,14 +171,8 @@ def examples(self) -> List[Message]: ) elif args.retrieval_type == "knowledge_tools": # TODO(wugaosheng) Add knowledge base tool retriever for tool selection - # tool_results: 
Dict = tool_retriever(query)["documents"] tool_results = asyncio.run(tool_retriever(query))["documents"] - selected_tools = [] - for item in tool_results: - tool_name = tool_map[item["meta"]["tool_name"]] - selected_tools.append(tool_name) - - # selected_tools = [tool_map[item['meta']["tool_name"]] for item in tool_results] + selected_tools = [tool_map[item["meta"]["tool_name"]] for item in tool_results] agent = FunctionalAgentWithQueryPlanning( # type: ignore llm=llm, top_k=3, From 51c0137414dc4f89182add85bab61f5b5f8f2519 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 19 Dec 2023 00:41:43 +0000 Subject: [PATCH 13/43] Add system prompt --- .../agents/functional_agent_with_retrieval.py | 12 +++++++---- .../tools/openai_search_tool.py | 21 +++++++++++++++---- .../examples/knowledge_tools_example.py | 9 +++++--- erniebot-agent/examples/text_summarization.py | 5 ++++- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 558293bc0..3e756f28a 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -9,7 +9,13 @@ ToolResponse, ) from erniebot_agent.file_io.base import File -from erniebot_agent.messages import AIMessage, FunctionMessage, HumanMessage, Message +from erniebot_agent.messages import ( + AIMessage, + FunctionMessage, + HumanMessage, + Message, + SystemMessage, +) from erniebot_agent.prompt import PromptTemplate from erniebot_agent.retrieval import BaizhongSearch from erniebot_agent.tools.base import Tool @@ -392,11 +398,9 @@ def __init__(self, top_k: int = 3, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) self.top_k = top_k self.threshold = threshold - self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) - # self.search_tool = 
KnowledgeBaseTool() + self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。") async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: - # RAG chat_history: List[Message] = [] actions_taken: List[AgentAction] = [] files_involved: List[AgentFile] = [] diff --git a/erniebot-agent/erniebot_agent/tools/openai_search_tool.py b/erniebot-agent/erniebot_agent/tools/openai_search_tool.py index 240c2e153..2beb06296 100644 --- a/erniebot-agent/erniebot_agent/tools/openai_search_tool.py +++ b/erniebot-agent/erniebot_agent/tools/openai_search_tool.py @@ -29,12 +29,21 @@ class OpenAISearchTool(Tool): ouptut_type: Type[ToolParameterView] = OpenAISearchToolOutputView def __init__( - self, name, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None + self, + name, + description, + db, + threshold: float = 0.0, + input_type=None, + output_type=None, + examples=None, + return_meta_data: bool = False, ) -> None: super().__init__() self.name = name self.db = db self.description = description + self.return_meta_data = return_meta_data self.few_shot_examples = [] if input_type is not None: self.input_type = input_type @@ -49,9 +58,13 @@ async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, docs = [] for doc, score in documents: if score > self.threshold: - docs.append( - {"document": doc.page_content, "title": doc.metadata["source"], "meta": doc.metadata} - ) + new_doc = {"document": doc.page_content} + if self.return_meta_data: + new_doc["meta"] = doc.metadata + if "source" in doc.metadata: + new_doc["title"] = doc.metadata["source"] + + docs.append(new_doc) return {"documents": docs} diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index 415755544..6061eb0fe 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ 
-146,7 +146,11 @@ def examples(self) -> List[Message]: faiss_tool = FAISS.from_documents(docs, embeddings) tool_retriever = OpenAISearchTool( # type: ignore - name="tool_retriever", description="用于检索与query相关的tools列表", db=faiss_tool, threshold=0.1 + name="tool_retriever", + description="用于检索与query相关的tools列表", + db=faiss_tool, + threshold=0.1, + return_meta_data=True, ) queries = [ @@ -170,7 +174,6 @@ def examples(self) -> List[Message]: memory=memory, ) elif args.retrieval_type == "knowledge_tools": - # TODO(wugaosheng) Add knowledge base tool retriever for tool selection tool_results = asyncio.run(tool_retriever(query))["documents"] selected_tools = [tool_map[item["meta"]["tool_name"]] for item in tool_results] agent = FunctionalAgentWithQueryPlanning( # type: ignore @@ -178,7 +181,7 @@ def examples(self) -> List[Message]: top_k=3, # tools=toolkit.get_tools() + [city_management, city_design, city_lighting], # tools = [NotesTool(),city_management, city_design, city_lighting], - tools=[NotesTool] + selected_tools, + tools=[NotesTool()] + selected_tools, memory=memory, ) diff --git a/erniebot-agent/examples/text_summarization.py b/erniebot-agent/examples/text_summarization.py index 351741dc8..175758d61 100644 --- a/erniebot-agent/examples/text_summarization.py +++ b/erniebot-agent/examples/text_summarization.py @@ -27,10 +27,13 @@ def summarize_text(text: str): print(f"Summarizing text with total chunks: {len(chunks)}") for i, chunk in enumerate(chunks): messages = [create_abstract(chunk)] - summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token) + summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token).rbody[ + "result" + ] print(summary) summaries.append(summary) + # breakpoint() combined_summary = "\n".join(summaries) combined_summary = combined_summary[:7000] messages = [create_abstract(combined_summary)] From 3320e7b508a9f0193e966b525acdfeb482bc3a26 Mon Sep 17 00:00:00 2001 From: w5688414 Date: 
def get_args(argv=None):
    """Parse the command-line options of the FAISS indexing example.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace: The parsed options.
    """
    # yapf: disable
    parser = argparse.ArgumentParser()
    parser.add_argument('--faiss_name', default="faiss_index", help="The faiss index")
    # The summary and full-text indexes need distinct default paths: with a
    # shared default the second ``save_local`` overwrites the first. The
    # values match the index names loaded by knowledge_tools_example.py.
    parser.add_argument('--summary_name', default="summary", help="The summary text index")
    parser.add_argument('--fulltext_name', default="fulltext", help="The full text index")
    parser.add_argument('--file_path', default="data/output.jsonl", help="The data output path")
    parser.add_argument('--indexing_type', choices=['summary_fulltext', 'common'],
                        default="common", help="The indexing types")
    args = parser.parse_args(argv)
    # yapf: enable
    return args
class RecursiveDocuments:
    """Pair of FAISS indexes: one over document abstracts, one over full-text chunks.

    ``vector_db`` holds the summary index (unchanged from before);
    ``fulltext_db`` keeps the full-text index, which used to be built or
    loaded and then dropped.
    """

    def __init__(self, summary_name, fulltext_name, file_path=None) -> None:
        """Remember the index locations and load or build both indexes.

        Args:
            summary_name: Local path of the summary index.
            fulltext_name: Local path of the full-text index.
            file_path: JSON Lines file with ``content`` and ``abstract``
                fields; only needed when the indexes do not exist yet.
        """
        self.summary_name = summary_name
        self.fulltext_name = fulltext_name
        self.file_path = file_path
        self.fulltext_db = None
        self.vector_db = self.init_db()

    def init_db(
        self,
    ):
        """Load both indexes from disk, or build and save them from ``file_path``.

        Returns:
            The summary index. The full-text index is stored on
            ``self.fulltext_db``.

        Raises:
            ValueError: If an index is missing and no ``file_path`` was given.
        """
        if os.path.exists(self.summary_name) and os.path.exists(self.fulltext_name):
            summary_faiss = FAISS.load_local(self.summary_name, embeddings)
            fulltext_faiss = FAISS.load_local(self.fulltext_name, embeddings)
        else:
            if self.file_path is None:
                raise ValueError(
                    "file_path is required to build the indexes because "
                    f"{self.summary_name!r} and/or {self.fulltext_name!r} do not exist"
                )
            list_data = read_data(self.file_path)
            doc_summary = []
            doc_fulltext = []
            text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0)
            for item in list_data:
                full_texts = Document(page_content=item["content"])
                abstract = Document(page_content=item["abstract"])
                # Full texts are chunked; abstracts are indexed whole.
                doc_fulltext.extend(text_splitter.split_documents([full_texts]))
                doc_summary.append(abstract)

            summary_faiss = FAISS.from_documents(doc_summary, embeddings)
            summary_faiss.save_local(self.summary_name)

            fulltext_faiss = FAISS.from_documents(doc_fulltext, embeddings)
            fulltext_faiss.save_local(self.fulltext_name)
        self.fulltext_db = fulltext_faiss
        return summary_faiss

    def search(self, query, top_k=4, use_fulltext=False):
        """Run a similarity search for ``query``.

        Args:
            query: The query text.
            top_k: Number of results requested from the index.
            use_fulltext: Search the full-text index instead of the summary
                index. Defaults to False, i.e. the previous behaviour.

        Returns:
            The result of the selected index's ``similarity_search``.
        """
        target_db = self.fulltext_db if use_fulltext else self.vector_db
        return target_db.similarity_search(query, k=top_k)
+ args = get_args() + if args.indexing_type == "common": + faiss_search = SemanticSearch(args.faiss_name, args.file_path) + docs = faiss_search.search(query) + print(docs) + else: + recursive_search = RecursiveDocuments(args.summary_name, args.fulltext_name, args.file_path) diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index 6061eb0fe..f60746cd7 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -103,19 +103,10 @@ def examples(self) -> List[Message]: name="city_lighting", description="提供关于城市照明管理规定的信息", db=baizhong_db, threshold=0.1 ) - summary_tool = BaizhongSearchTool( - name="text_summary_search", description="使用这个工具总结与作者生活相关的问题", db=baizhong_db, threshold=0.1 - ) - vector_tool = BaizhongSearchTool( - name="fulltext_search", - description="使用这个工具检索特定的上下文,以回答有关作者生活的特定问题", - db=baizhong_db, - threshold=0.1, - ) tool_retriever = BaizhongSearchTool( name="tool_retriever", description="用于检索与query相关的tools列表", db=baizhong_db, threshold=0.1 ) - elif args.search_engine == "openai": + elif args.search_engine == "openai" and args.retrieval_type == "knowledge_tools": embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") faiss = FAISS.load_local("城市管理执法办法", embeddings) openai_city_management = OpenAISearchTool( @@ -152,6 +143,19 @@ def examples(self) -> List[Message]: threshold=0.1, return_meta_data=True, ) + elif args.search_engine == "openai" and args.retrieval_type == "summary_fulltext_tools": + embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") + summary_faiss = FAISS.load_local("summary", embeddings) + summary_tool = OpenAISearchTool( + name="text_summary_search", description="使用这个工具总结与建筑规范相关的问题", db=summary_faiss, threshold=0.1 + ) + fulltext_faiss = FAISS.load_local("fulltext", embeddings) + vector_tool = OpenAISearchTool( + name="fulltext_search", + description="使用这个工具检索特定的上下文,以回答有关建筑规范具体的问题", + 
db=fulltext_faiss, + threshold=0.1, + ) queries = [ "量化交易", @@ -166,7 +170,7 @@ def examples(self) -> List[Message]: toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") for query in queries: memory = WholeMemory() - if args.retrieval_type == "summary_fulltext": + if args.retrieval_type == "summary_fulltext_tools": agent = FunctionalAgentWithQueryPlanning( # type: ignore llm=llm, top_k=3, diff --git a/erniebot-agent/examples/text_summarization.py b/erniebot-agent/examples/text_summarization.py index 175758d61..9f4024707 100644 --- a/erniebot-agent/examples/text_summarization.py +++ b/erniebot-agent/examples/text_summarization.py @@ -17,13 +17,9 @@ def summarize_text(text: str): - if not text: - return "Error: No text to summarize" summaries = [] chunks = list(split_text(text, max_length=4096)) - scroll_ratio = 1 / len(chunks) - print(scroll_ratio) print(f"Summarizing text with total chunks: {len(chunks)}") for i, chunk in enumerate(chunks): messages = [create_abstract(chunk)] @@ -33,12 +29,13 @@ def summarize_text(text: str): print(summary) summaries.append(summary) - # breakpoint() combined_summary = "\n".join(summaries) combined_summary = combined_summary[:7000] messages = [create_abstract(combined_summary)] - final_summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token) + final_summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token).rbody[ + "result" + ] print("Final summary length: ", len(final_summary)) print(final_summary) return final_summary From 76604f800e25acd8b39a2a2146d2b9b203798608 Mon Sep 17 00:00:00 2001 From: w5688414 <15623211472@163.com> Date: Tue, 19 Dec 2023 10:50:50 +0800 Subject: [PATCH 15/43] Update erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py Co-authored-by: Sijun He --- .../erniebot_agent/agents/functional_agent_with_retrieval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 3e756f28a..50f7a8de4 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -155,8 +155,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A step_input = HumanMessage( content=self.rag_prompt.format(query=prompt, documents=results["documents"]) ) - fake_chat_history: List[Message] = [] - fake_chat_history.append(step_input) + fake_chat_history: List[Message] = [step_input] llm_resp = await self._async_run_llm_without_hooks( messages=fake_chat_history, functions=None, From e28cf0bd950144d3eb9d448e646663633b315641 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 19 Dec 2023 13:54:18 +0000 Subject: [PATCH 16/43] Add planning and execute rules --- .../agents/functional_agent_with_retrieval.py | 67 ++++++++++++++++++- .../examples/knowledge_tools_example.py | 22 ++++-- src/erniebot/resources/chat_completion.py | 1 - 3 files changed, 79 insertions(+), 11 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 50f7a8de4..684f688e9 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -392,17 +392,27 @@ async def _maybe_retrieval( return results +QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,每个子问题必须足够简单,要求: +1.严格按照【JSON格式】的形式输出:{'子问题1':'具体子问题1','子问题2':'具体子问题2'} +问题:{{prompt}} 子问题:""" + + class FunctionalAgentWithQueryPlanning(FunctionalAgent): - def __init__(self, top_k: int = 3, threshold: float = 0.1, **kwargs): + def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) self.top_k = top_k 
self.threshold = threshold self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。") + self.query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"]) + self.knowledge_base = knowledge_base + self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: - chat_history: List[Message] = [] + # chat_history: List[Message] = [] actions_taken: List[AgentAction] = [] files_involved: List[AgentFile] = [] + chat_history: List[Message] = [] + # 会有无限循环调用工具的问题 # next_step_input = HumanMessage( # content=f"请选择合适的工具来回答:{prompt},如果需要的话,可以对把问题分解成子问题,然后每个子问题选择合适的工具回答。" @@ -419,5 +429,56 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A self.memory.add_message(chat_history[-1]) return response num_steps_taken += 1 - response = self._create_stopped_response(chat_history, actions_taken, files_involved) + # TODO(wugaosheng): Add manual planning and execute + # response = self._create_stopped_response(chat_history, actions_taken, files_involved) + return await self.plan_and_execute(prompt, actions_taken, files_involved) + + async def plan_and_execute(self, prompt, actions_taken, files_involved): + step_input = HumanMessage(content=self.query_transform.format(prompt=prompt)) + fake_chat_history: List[Message] = [step_input] + llm_resp = await self._async_run_llm_without_hooks( + messages=fake_chat_history, + functions=None, + system=self.system_message.content if self.system_message is not None else None, + ) + output_message = llm_resp.message + + json_results = self._parse_results(output_message.content) + sub_queries = json_results.values() + retrieval_results = [] + duplicates = set() + for query in sub_queries: + documents = await self.knowledge_base(query, top_k=self.top_k, filters=None) + docs = [item for item in documents["documents"]] + for doc in docs: + if 
doc["document"] not in duplicates: + duplicates.add(doc["document"]) + retrieval_results.append(doc) + step_input = HumanMessage( + content=self.rag_prompt.format(query=prompt, documents=retrieval_results[:3]) + ) + chat_history: List[Message] = [step_input] + llm_resp = await self._async_run_llm_without_hooks( + messages=chat_history, + functions=None, + system=self.system_message.content if self.system_message is not None else None, + ) + + output_message = llm_resp.message + chat_history.append(output_message) + response = self._create_finished_response(chat_history, actions_taken, files_involved) + self.memory.add_message(chat_history[0]) + self.memory.add_message(chat_history[-1]) return response + + def _parse_results(self, results): + left_index = results.find("{") + right_index = results.rfind("}") + if left_index == -1 or right_index == -1: + # if invalid json, use Functional Agent + return {"is_relevant": False} + try: + return json.loads(results[left_index : right_index + 1]) + except Exception: + # if invalid json, use Functional Agent + return {"is_relevant": False} diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index f60746cd7..0110d2fa9 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -143,6 +143,13 @@ def examples(self) -> List[Message]: threshold=0.1, return_meta_data=True, ) + fulltext_faiss = FAISS.load_local("fulltext", embeddings) + vector_tool = OpenAISearchTool( + name="fulltext_search", + description="使用这个工具检索特定的上下文,以回答有关建筑规范具体的问题", + db=fulltext_faiss, + threshold=0.1, + ) elif args.search_engine == "openai" and args.retrieval_type == "summary_fulltext_tools": embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") summary_faiss = FAISS.load_local("summary", embeddings) @@ -158,14 +165,14 @@ def examples(self) -> List[Message]: ) queries = [ - "量化交易", - "OpenAI管理层变更会带来哪些影响?", - 
"城市景观照明中有过度照明的规定是什么?", - "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", + # "量化交易", + # "OpenAI管理层变更会带来哪些影响?", + # "城市景观照明中有过度照明的规定是什么?", + # "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", "请比较一下城市设计管理和照明管理规定的区别?", - "这几篇文档主要内容是什么?", - "今天天气怎么样?", - "abcabc", + # "这几篇文档主要内容是什么?", + # "今天天气怎么样?", + # "abcabc", ] toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") for query in queries: @@ -183,6 +190,7 @@ def examples(self) -> List[Message]: agent = FunctionalAgentWithQueryPlanning( # type: ignore llm=llm, top_k=3, + knowledge_base=vector_tool, # tools=toolkit.get_tools() + [city_management, city_design, city_lighting], # tools = [NotesTool(),city_management, city_design, city_lighting], tools=[NotesTool()] + selected_tools, diff --git a/src/erniebot/resources/chat_completion.py b/src/erniebot/resources/chat_completion.py index c38577912..df91fe5e6 100644 --- a/src/erniebot/resources/chat_completion.py +++ b/src/erniebot/resources/chat_completion.py @@ -493,7 +493,6 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: # stream stream = kwargs.get("stream", False) - return RequestWithStream( path=path, params=params, From a6d331a7c2f089c68019279076bbb5a325d045a8 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 20 Dec 2023 02:55:37 +0000 Subject: [PATCH 17/43] Update planning logic --- .../agents/functional_agent_with_retrieval.py | 10 +++++++++- erniebot-agent/examples/knowledge_tools_example.py | 14 +++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py index 684f688e9..60b2b1c74 100644 --- a/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -397,6 +397,14 @@ async def _maybe_retrieval( 问题:{{prompt}} 子问题:""" +OPENAI_RAG_PROMPT = """检索结果: +{% for doc in documents %} + 
第{{loop.index}}个段落: {{doc['document']}} +{% endfor %} +检索语句: {{query}} +请根据以上检索结果回答检索语句的问题""" + + class FunctionalAgentWithQueryPlanning(FunctionalAgent): def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) @@ -405,7 +413,7 @@ def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwa self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。") self.query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"]) self.knowledge_base = knowledge_base - self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) + self.rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"]) async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: # chat_history: List[Message] = [] diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py index 0110d2fa9..d6e90f7f3 100644 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ b/erniebot-agent/examples/knowledge_tools_example.py @@ -165,14 +165,14 @@ def examples(self) -> List[Message]: ) queries = [ - # "量化交易", - # "OpenAI管理层变更会带来哪些影响?", - # "城市景观照明中有过度照明的规定是什么?", - # "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", + "量化交易", + "OpenAI管理层变更会带来哪些影响?", + "城市景观照明中有过度照明的规定是什么?", + "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", "请比较一下城市设计管理和照明管理规定的区别?", - # "这几篇文档主要内容是什么?", - # "今天天气怎么样?", - # "abcabc", + "这几篇文档主要内容是什么?", + "今天天气怎么样?", + "abcabc", ] toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") for query in queries: From a78fa660f4a309ec2fe37a27f49c8d3d371d7eb8 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 20 Dec 2023 11:58:04 +0000 Subject: [PATCH 18/43] Add automatic prompt engineer --- .../examples/automatic_prompt_engineer.py | 103 ++++++++++++++++++ erniebot-agent/examples/utils.py | 36 +++++- 2 files changed, 137 insertions(+), 2 
deletions(-) create mode 100644 erniebot-agent/examples/automatic_prompt_engineer.py diff --git a/erniebot-agent/examples/automatic_prompt_engineer.py b/erniebot-agent/examples/automatic_prompt_engineer.py new file mode 100644 index 000000000..7105c62bf --- /dev/null +++ b/erniebot-agent/examples/automatic_prompt_engineer.py @@ -0,0 +1,103 @@ +import argparse +import asyncio + +from erniebot_agent.agents.prompt_agent import PromptAgent +from erniebot_agent.chat_models import ERNIEBot +from erniebot_agent.memory import WholeMemory +from erniebot_agent.tools.openai_search_tool import OpenAISearchTool +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.vectorstores import FAISS +from prettytable import PrettyTable +from utils import create_description, create_questions, erniebot_chat, read_data + +import erniebot + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--api_type", default=None, type=str, help="The API Key.") +parser.add_argument("--access_token", default=None, type=str, help="The secret key.") +parser.add_argument("--summarization_path", default='data/data.jsonl', type=str, help="The output path.") +parser.add_argument("--number_of_prompts", default=3, type=int, help="The number of tool descriptions.") +parser.add_argument("--num_questions", default=3, type=int, help="The number of few shot questions.") +args = parser.parse_args() +# yapf: enable + + +def generate_candidate_prompts(description, number_of_prompts): + prompts = [] + for i in range(number_of_prompts): + messages = [create_description(description)] + results = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token)["result"] + prompts.append(results) + return prompts + + +def generate_candidate_questions(description, num_questions=5): + messages = [create_questions(description, num_questions=num_questions)] + results = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token)["result"] + return 
results + + +if __name__ == "__main__": + erniebot.api_type = args.api_type + erniebot.access_token = args.access_token + + embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") + faiss = FAISS.load_local("城市管理执法办法", embeddings) + + list_data = read_data(args.summarization_path) + doc = list_data[0] + tool_descriptions = generate_candidate_prompts(doc["abstract"], number_of_prompts=args.number_of_prompts) + print(tool_descriptions) + + questions = generate_candidate_questions(doc["abstract"], num_questions=args.num_questions).split("\n") + print(questions) + + prompts = tool_descriptions + prompt_results = {prompt: {"correct": 0.0, "total": 0.0} for prompt in prompts} + + # Initialize the table + table = PrettyTable() + table_field_names = ["Prompt"] + [ + f"question {i+1}-{j+1}" for j, prompt in enumerate(questions) for i in range(questions.count(prompt)) + ] + table.field_names = table_field_names + + # Wrap the text in the "Prompt" column + table.max_width["Prompt"] = 100 + + llm = ERNIEBot(model="ernie-bot") + best_prompt = None + best_percentage = 0.0 + for i, tool_description in enumerate(tool_descriptions): + openai_city_management = OpenAISearchTool( + name="city_administrative_law_enforcement", + description=tool_description, + db=faiss, + threshold=0.1, + ) + row = [tool_description] + resps = [] + for query in questions: + agent = PromptAgent(memory=WholeMemory(), llm=llm, tools=[openai_city_management]) + response = asyncio.run(agent.async_run(query)) + resps.append(response) + if response is True: + prompt_results[tool_description]["correct"] += 1 + row.append("✅") + else: + row.append("❌") + prompt_results[tool_description]["total"] += 1 + table.add_row(row) + + print(table) + for i, prompt in enumerate(prompts): + correct = prompt_results[prompt]["correct"] + total = prompt_results[prompt]["total"] + percentage = (correct / total) * 100 + print(f"Prompt {i+1} got {percentage:.2f}% correct.") + if percentage > best_percentage: + 
best_percentage = percentage + best_prompt = tool_description + + print(f"The best prompt was '{best_prompt}' with a correctness of {best_percentage:.2f}%.") diff --git a/erniebot-agent/examples/utils.py b/erniebot-agent/examples/utils.py index 6a5600631..fcb0803d2 100644 --- a/erniebot-agent/examples/utils.py +++ b/erniebot-agent/examples/utils.py @@ -29,6 +29,38 @@ def create_abstract(chunk: str) -> Dict[str, str]: } +def create_questions(chunk: str, num_questions: int = 5) -> Dict[str, str]: + """Create a message for the chat completion + + Args: + chunk (str): The chunk of text to summarize + question (str): The question to answer + + Returns: + Dict[str, str]: The message to send to the chat completion + """ + return { + "role": "user", + "content": f"""{chunk},请根据上面的摘要,生成{num_questions}个问题,问题内容和形式要多样化,分条列举出来.""", + } + + +def create_description(chunk: str) -> Dict[str, str]: + """Create a message for the chat completion + + Args: + chunk (str): The chunk of text to summarize + question (str): The question to answer + + Returns: + Dict[str, str]: The message to send to the chat completion + """ + return { + "role": "user", + "content": f"""{chunk},请根据上面的摘要,生成一个简短的描述,不超过30字.""", + } + + def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]: """Split text into chunks of a maximum length @@ -60,13 +92,13 @@ def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]: def erniebot_chat( - messages, model="ernie-bot-8k", api_type="aistudio", access_token=None, functions=None, **kwargs + messages, model="ernie-bot", api_type="aistudio", access_token=None, functions=None, **kwargs ): """ Args: messages: dict or list, 输入的消息(message) model: str, 模型名称 - api_type: str, 接口类型,可选值包括 'aistudio' 和 'webchat' + api_type: str, 接口类型,可选值包括 'aistudio' 和 'qianfan' access_token: str, 访问令牌(access token) functions: list, 函数列表 kwargs: 其他参数 From c25e6aa696c0722c5ea5ef1d466b08a7f10036ad Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 21 Dec 
2023 02:47:04 +0000 Subject: [PATCH 19/43] Update prompt engineer --- .../examples/automatic_prompt_engineer.py | 34 +++++++++++++++---- erniebot-agent/examples/utils.py | 19 +++++------ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/erniebot-agent/examples/automatic_prompt_engineer.py b/erniebot-agent/examples/automatic_prompt_engineer.py index 7105c62bf..685c197c0 100644 --- a/erniebot-agent/examples/automatic_prompt_engineer.py +++ b/erniebot-agent/examples/automatic_prompt_engineer.py @@ -8,7 +8,13 @@ from langchain.embeddings.openai import OpenAIEmbeddings from langchain.vectorstores import FAISS from prettytable import PrettyTable -from utils import create_description, create_questions, erniebot_chat, read_data +from utils import ( + create_description, + create_keywords, + create_questions, + erniebot_chat, + read_data, +) import erniebot @@ -19,6 +25,7 @@ parser.add_argument("--summarization_path", default='data/data.jsonl', type=str, help="The output path.") parser.add_argument("--number_of_prompts", default=3, type=int, help="The number of tool descriptions.") parser.add_argument("--num_questions", default=3, type=int, help="The number of few shot questions.") +parser.add_argument("--num_keywords", default=-1, type=int, help="The number of few shot questions.") args = parser.parse_args() # yapf: enable @@ -27,14 +34,24 @@ def generate_candidate_prompts(description, number_of_prompts): prompts = [] for i in range(number_of_prompts): messages = [create_description(description)] - results = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token)["result"] + results = erniebot_chat( + messages, model="ernie-bot-4", api_type=args.api_type, access_token=args.access_token + )["result"] prompts.append(results) return prompts -def generate_candidate_questions(description, num_questions=5): - messages = [create_questions(description, num_questions=num_questions)] - results = erniebot_chat(messages, api_type=args.api_type, 
access_token=args.access_token)["result"] +def generate_candidate_questions( + description, num_questions: int = -1, num_keywords: int = -1, temperature=1e-10 +): + if num_questions > 0: + messages = [create_questions(description, num_questions=num_questions)] + elif num_keywords > 0: + messages = [create_keywords(description, num_keywords=num_keywords)] + + results = erniebot_chat( + messages, api_type=args.api_type, access_token=args.access_token, temperature=temperature + )["result"] return results @@ -48,10 +65,14 @@ def generate_candidate_questions(description, num_questions=5): list_data = read_data(args.summarization_path) doc = list_data[0] tool_descriptions = generate_candidate_prompts(doc["abstract"], number_of_prompts=args.number_of_prompts) + tool_descriptions = list(set(tool_descriptions)) print(tool_descriptions) questions = generate_candidate_questions(doc["abstract"], num_questions=args.num_questions).split("\n") - print(questions) + + if args.num_keywords > 0: + keywords = generate_candidate_questions(doc["abstract"], num_keywords=args.num_keywords).split("\n") + questions += keywords prompts = tool_descriptions prompt_results = {prompt: {"correct": 0.0, "total": 0.0} for prompt in prompts} @@ -90,6 +111,7 @@ def generate_candidate_questions(description, num_questions=5): prompt_results[tool_description]["total"] += 1 table.add_row(row) + print(f"生成的问题如下:{questions}") print(table) for i, prompt in enumerate(prompts): correct = prompt_results[prompt]["correct"] diff --git a/erniebot-agent/examples/utils.py b/erniebot-agent/examples/utils.py index fcb0803d2..25a373c56 100644 --- a/erniebot-agent/examples/utils.py +++ b/erniebot-agent/examples/utils.py @@ -29,19 +29,18 @@ def create_abstract(chunk: str) -> Dict[str, str]: } -def create_questions(chunk: str, num_questions: int = 5) -> Dict[str, str]: - """Create a message for the chat completion +def create_questions(chunk: str, num_questions: int = 5): + return { + "role": "user", + "content": 
f"""{chunk},请根据上面的摘要,生成{num_questions}个问题,问题内容和形式要多样化,口语化,不允许重复,分条列举出来.""", + } - Args: - chunk (str): The chunk of text to summarize - question (str): The question to answer - Returns: - Dict[str, str]: The message to send to the chat completion - """ +def create_keywords(chunk: str, num_keywords: int = 3): return { "role": "user", - "content": f"""{chunk},请根据上面的摘要,生成{num_questions}个问题,问题内容和形式要多样化,分条列举出来.""", + "content": f"""{chunk},请根据上面的摘要,抽取{num_keywords}个关键字或者简短关键句,分条列举出来。 + 要求:只需要输出关键字或者简短的关键句,不需要输出其它的内容.""", } @@ -57,7 +56,7 @@ def create_description(chunk: str) -> Dict[str, str]: """ return { "role": "user", - "content": f"""{chunk},请根据上面的摘要,生成一个简短的描述,不超过30字.""", + "content": f"""{chunk},请根据上面的摘要,生成一个简短的描述,不超过40字.""", } From ff980bac8a4db4a74bb4d7c7bf5ffe279e700b22 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 21 Dec 2023 02:53:25 +0000 Subject: [PATCH 20/43] Add prompt agent --- .../erniebot_agent/agents/prompt_agent.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 erniebot-agent/erniebot_agent/agents/prompt_agent.py diff --git a/erniebot-agent/erniebot_agent/agents/prompt_agent.py b/erniebot-agent/erniebot_agent/agents/prompt_agent.py new file mode 100644 index 000000000..ad4f94951 --- /dev/null +++ b/erniebot-agent/erniebot_agent/agents/prompt_agent.py @@ -0,0 +1,39 @@ +from typing import Any, List, Optional + +from erniebot_agent.agents import FunctionalAgent +from erniebot_agent.agents.schema import AgentAction, AgentFile +from erniebot_agent.file_io.base import File +from erniebot_agent.messages import HumanMessage, Message + + +class PromptAgent(FunctionalAgent): + def __init__(self, top_k: int = 2, threshold: float = 0.1, token_limit: int = 3000, **kwargs): + super().__init__(**kwargs) + self.top_k = top_k + self.threshold = threshold + self.token_limit = token_limit + # self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。") + + async def _async_run(self, prompt: str, files: 
Optional[List[File]] = None) -> Any: + actions_taken: List[AgentAction] = [] + files_involved: List[AgentFile] = [] + chat_history: List[Message] = [] + + next_step_input = HumanMessage(content=prompt) + curr_step_output = await self._async_step( + next_step_input, chat_history, actions_taken, files_involved + ) + return curr_step_output + + async def _async_step( + self, + step_input, + chat_history: List[Message], + actions: List[AgentAction], + files: List[AgentFile], + ) -> Optional[Any]: + maybe_action = await self._async_plan(step_input, chat_history) + if isinstance(maybe_action, AgentAction): + return True + else: + return False From f6ba05b5a0a50e8f38c9c357d11909c5b04f6a83 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 16 Jan 2024 11:17:07 +0000 Subject: [PATCH 21/43] Resolve conflicts --- .../src/erniebot_agent/agents/__init__.py | 3 +- .../agents/function_agent_with_retrieval.py | 80 +++++++++++-------- 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/__init__.py b/erniebot-agent/src/erniebot_agent/agents/__init__.py index 1515c0b39..20909f669 100644 --- a/erniebot-agent/src/erniebot_agent/agents/__init__.py +++ b/erniebot-agent/src/erniebot_agent/agents/__init__.py @@ -15,10 +15,9 @@ from erniebot_agent.agents.agent import Agent from erniebot_agent.agents.function_agent import FunctionAgent from erniebot_agent.agents.function_agent_with_retrieval import ( + ContextAugmentedFunctionalAgent, FunctionAgentWithRetrieval, FunctionAgentWithRetrievalScoreTool, FunctionAgentWithRetrievalTool, - ContextAugmentedFunctionalAgent, FunctionalAgentWithQueryPlanning, - ) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 23156f688..8a84a7f3f 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ 
b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -9,6 +9,7 @@ DEFAULT_FINISH_STEP, AgentResponse, AgentStep, + AgentStepWithFiles, EndStep, File, PluginStep, @@ -24,14 +25,7 @@ Message, SearchInfo, ) -from erniebot_agent.file_io.base import File -from erniebot_agent.messages import ( - AIMessage, - FunctionMessage, - HumanMessage, - Message, - SystemMessage, -) +from erniebot_agent.messages import SystemMessage from erniebot_agent.prompt import PromptTemplate from erniebot_agent.retrieval import BaizhongSearch from erniebot_agent.tools.base import Tool @@ -343,24 +337,34 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) # Direct Prompt next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作,并且检索只允许调用一次") + chat_history.append(next_step_input) tool_resp = ToolResponse(json=tool_ret_json, files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) num_steps_taken = 0 while num_steps_taken < self.max_steps: - curr_step_output = await self._async_step( - next_step_input, chat_history, actions_taken, files_involved - ) - if curr_step_output is None: - response = self._create_finished_response(chat_history, actions_taken, files_involved) + curr_step, new_messages = await self._step(chat_history) + chat_history.extend(new_messages) + if isinstance(curr_step, ToolStep): + steps_taken.append(curr_step) + + elif isinstance(curr_step, PluginStep): + steps_taken.append(curr_step) + # 预留 调用了Plugin之后不结束的接口 + + # 此处为调用了Plugin之后直接结束的Plugin + curr_step = DEFAULT_FINISH_STEP + + if isinstance(curr_step, EndStep): # plugin with action + response = self._create_finished_response(chat_history, steps_taken, curr_step=curr_step) self.memory.add_message(chat_history[0]) self.memory.add_message(chat_history[-1]) return response num_steps_taken += 1 - response = 
self._create_stopped_response(chat_history, actions_taken, files_involved) + response = self._create_stopped_response(chat_history, steps_taken) return response else: - logger.info( + _logger.info( f"Irrelevant retrieval results. Fallbacking to FunctionalAgent for the query: {prompt}" ) return await super().run(prompt, files) @@ -376,7 +380,7 @@ async def _maybe_retrieval( return results -class ContextAugmentedFunctionalAgent(FunctionalAgent): +class ContextAugmentedFunctionalAgent(FunctionAgent): def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) self.knowledge_base = knowledge_base @@ -385,13 +389,12 @@ def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: fl self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) self.search_tool = KnowledgeBaseTool() - async def _run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: + async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: results = await self._maybe_retrieval(prompt) if len(results["documents"]) > 0: # RAG chat_history: List[Message] = [] - actions_taken: List[AgentAction] = [] - files_involved: List[AgentFile] = [] + steps_taken: List[AgentStep] = [] tool_args = json.dumps({"query": prompt}, ensure_ascii=False) await self._callback_manager.on_tool_start( @@ -494,7 +497,7 @@ async def _maybe_retrieval( 请根据以上检索结果回答检索语句的问题""" -class FunctionalAgentWithQueryPlanning(FunctionalAgent): +class FunctionalAgentWithQueryPlanning(FunctionAgent): def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) self.top_k = top_k @@ -504,33 +507,42 @@ def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwa self.knowledge_base = knowledge_base self.rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"]) - async def _run(self, prompt: 
str, files: Optional[List[File]] = None) -> AgentResponse: - # chat_history: List[Message] = [] - actions_taken: List[AgentAction] = [] - files_involved: List[AgentFile] = [] + async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: chat_history: List[Message] = [] + steps_taken: List[AgentStep] = [] # 会有无限循环调用工具的问题 # next_step_input = HumanMessage( # content=f"请选择合适的工具来回答:{prompt},如果需要的话,可以对把问题分解成子问题,然后每个子问题选择合适的工具回答。" # ) - next_step_input = HumanMessage(content=prompt) + run_input = await HumanMessage.create_with_files( + prompt, files or [], include_file_urls=self.file_needs_url + ) + chat_history.append(run_input) num_steps_taken = 0 while num_steps_taken < self.max_steps: - curr_step_output = await self._async_step( - next_step_input, chat_history, actions_taken, files_involved - ) - if curr_step_output is None: - response = self._create_finished_response(chat_history, actions_taken, files_involved) + curr_step, new_messages = await self._step(chat_history) + chat_history.extend(new_messages) + if isinstance(curr_step, ToolStep): + steps_taken.append(curr_step) + + elif isinstance(curr_step, PluginStep): + steps_taken.append(curr_step) + # 预留 调用了Plugin之后不结束的接口 + + # 此处为调用了Plugin之后直接结束的Plugin + curr_step = DEFAULT_FINISH_STEP + + if isinstance(curr_step, EndStep): + response = self._create_finished_response(chat_history, steps_taken, curr_step) self.memory.add_message(chat_history[0]) self.memory.add_message(chat_history[-1]) return response num_steps_taken += 1 # TODO(wugaosheng): Add manual planning and execute - # response = self._create_stopped_response(chat_history, actions_taken, files_involved) - return await self.plan_and_execute(prompt, actions_taken, files_involved) + return await self.plan_and_execute(prompt, steps_taken, curr_step) - async def plan_and_execute(self, prompt, actions_taken, files_involved): + async def plan_and_execute(self, prompt, steps_taken: List[AgentStep], curr_step: AgentStepWithFiles): 
step_input = HumanMessage(content=self.query_transform.format(prompt=prompt)) fake_chat_history: List[Message] = [step_input] llm_resp = await self._async_run_llm_without_hooks( @@ -563,7 +575,7 @@ async def plan_and_execute(self, prompt, actions_taken, files_involved): output_message = llm_resp.message chat_history.append(output_message) - response = self._create_finished_response(chat_history, actions_taken, files_involved) + response = self._create_finished_response(chat_history, steps_taken, curr_step) self.memory.add_message(chat_history[0]) self.memory.add_message(chat_history[-1]) return response From e37f2bfd7226f168b37bf8a6d388772f2032f730 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 16 Jan 2024 11:51:00 +0000 Subject: [PATCH 22/43] Fix conflicts --- .../agents/function_agent_with_retrieval.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 8a84a7f3f..4d0af1c08 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -25,7 +25,6 @@ Message, SearchInfo, ) -from erniebot_agent.messages import SystemMessage from erniebot_agent.prompt import PromptTemplate from erniebot_agent.retrieval import BaizhongSearch from erniebot_agent.tools.base import Tool @@ -338,7 +337,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age # Direct Prompt next_step_input = HumanMessage(content=f"问题:{prompt},要求:请在第一步执行检索的操作,并且检索只允许调用一次") chat_history.append(next_step_input) - tool_resp = ToolResponse(json=tool_ret_json, files=[]) + tool_resp = ToolResponse(json=tool_ret_json, input_files=[], output_files=[]) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) num_steps_taken = 0 @@ -405,10 +404,8 @@ 
async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) fake_chat_history: List[Message] = [] fake_chat_history.append(step_input) - llm_resp = await self._async_run_llm_without_hooks( - messages=fake_chat_history, - functions=None, - system=self.system_message.content if self.system_message is not None else None, + llm_resp = await self.run_llm( + messages=fake_chat_history ) # Get RAG results @@ -428,12 +425,13 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age next_step_input = HumanMessage( content=f"背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" ) + chat_history.append(next_step_input) # Knowledge Retrieval Tool action = ToolAction(tool_name=self.search_tool.tool_name, tool_args=tool_args) # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) - chat_history.append(next_step_input) + # next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) + tool_resp = ToolResponse(json=tool_ret_json, input_files=[], output_files=[]) steps_taken.append( ToolStep( @@ -502,7 +500,6 @@ def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwa super().__init__(**kwargs) self.top_k = top_k self.threshold = threshold - self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。") self.query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"]) self.knowledge_base = knowledge_base self.rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"]) @@ -542,13 +539,11 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age # TODO(wugaosheng): Add manual planning and execute return await self.plan_and_execute(prompt, steps_taken, curr_step) - async def plan_and_execute(self, prompt, steps_taken: List[AgentStep], curr_step: 
AgentStepWithFiles): + async def plan_and_execute(self, prompt, steps_taken: List[AgentStep], curr_step: AgentStep): step_input = HumanMessage(content=self.query_transform.format(prompt=prompt)) fake_chat_history: List[Message] = [step_input] - llm_resp = await self._async_run_llm_without_hooks( + llm_resp = await self.run_llm( messages=fake_chat_history, - functions=None, - system=self.system_message.content if self.system_message is not None else None, ) output_message = llm_resp.message @@ -567,15 +562,19 @@ async def plan_and_execute(self, prompt, steps_taken: List[AgentStep], curr_step content=self.rag_prompt.format(query=prompt, documents=retrieval_results[:3]) ) chat_history: List[Message] = [step_input] - llm_resp = await self._async_run_llm_without_hooks( - messages=chat_history, - functions=None, - system=self.system_message.content if self.system_message is not None else None, + llm_resp = await self.run_llm( + messages=chat_history ) output_message = llm_resp.message chat_history.append(output_message) - response = self._create_finished_response(chat_history, steps_taken, curr_step) + last_message = chat_history[-1] + response = AgentResponse( + text=last_message.content, + chat_history=chat_history, + steps=steps_taken, + status="FINISHED", + ) self.memory.add_message(chat_history[0]) self.memory.add_message(chat_history[-1]) return response From 6b95d87e2a99ed83defc73de8a3b66db03dff266 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 16 Jan 2024 11:53:27 +0000 Subject: [PATCH 23/43] Update retrieval tools --- .../agents/function_agent_with_retrieval.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 4d0af1c08..30f2d53a6 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ 
b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -9,7 +9,6 @@ DEFAULT_FINISH_STEP, AgentResponse, AgentStep, - AgentStepWithFiles, EndStep, File, PluginStep, @@ -404,9 +403,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) fake_chat_history: List[Message] = [] fake_chat_history.append(step_input) - llm_resp = await self.run_llm( - messages=fake_chat_history - ) + llm_resp = await self.run_llm(messages=fake_chat_history) # Get RAG results output_message = llm_resp.message @@ -427,11 +424,11 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) chat_history.append(next_step_input) # Knowledge Retrieval Tool - action = ToolAction(tool_name=self.search_tool.tool_name, tool_args=tool_args) + # action = ToolAction(tool_name=self.search_tool.tool_name, tool_args=tool_args) # return response tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) # next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) - + tool_resp = ToolResponse(json=tool_ret_json, input_files=[], output_files=[]) steps_taken.append( ToolStep( @@ -562,9 +559,7 @@ async def plan_and_execute(self, prompt, steps_taken: List[AgentStep], curr_step content=self.rag_prompt.format(query=prompt, documents=retrieval_results[:3]) ) chat_history: List[Message] = [step_input] - llm_resp = await self.run_llm( - messages=chat_history - ) + llm_resp = await self.run_llm(messages=chat_history) output_message = llm_resp.message chat_history.append(output_message) From 254df32fb95020093ca375d623502465eb7ad5f2 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 16 Jan 2024 11:57:25 +0000 Subject: [PATCH 24/43] Update name --- erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py b/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py index e1e4ac152..eaaeee425 
100644 --- a/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py @@ -31,7 +31,7 @@ class BaizhongSearchTool(Tool): ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView def __init__( - self, name, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None + self, description, db, threshold: float = 0.0, name: Optional[str]=None,input_type=None, output_type=None, examples=None ) -> None: super().__init__() self.name = name From 1b059d66c48a109ea783f9227968ee5e8ac244c2 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 16 Jan 2024 12:00:12 +0000 Subject: [PATCH 25/43] Update format --- erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py b/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py index eaaeee425..9b1814511 100644 --- a/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py @@ -31,7 +31,14 @@ class BaizhongSearchTool(Tool): ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView def __init__( - self, description, db, threshold: float = 0.0, name: Optional[str]=None,input_type=None, output_type=None, examples=None + self, + description, + db, + threshold: float = 0.0, + name: Optional[str] = None, + input_type=None, + output_type=None, + examples=None, ) -> None: super().__init__() self.name = name From add83cdca0d7bcc60cd8839248a8caccc1986a91 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 17 Jan 2024 08:34:53 +0000 Subject: [PATCH 26/43] Update unitests --- .../agents/function_agent_with_retrieval.py | 4 +-- ...unction_agent_with_retrieval_score_tool.py | 34 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py 
b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 30f2d53a6..4feb6ef9e 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -365,7 +365,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age _logger.info( f"Irrelevant retrieval results. Fallbacking to FunctionalAgent for the query: {prompt}" ) - return await super().run(prompt, files) + return await super()._run(prompt, files) async def _maybe_retrieval( self, @@ -466,7 +466,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age _logger.info( f"Irrelevant retrieval results. Fallbacking to FunctionAgent for the query: {prompt}" ) - return await super()._run(prompt) + return await super()._run(prompt, files) async def _maybe_retrieval( self, diff --git a/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py b/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py index 83e737050..a70da5f37 100644 --- a/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py +++ b/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py @@ -120,25 +120,25 @@ async def test_functional_agent_with_retrieval_retrieval_score_tool_run_retrieva assert response.text == "Text response" # HumanMessage - assert response.chat_history[0].content == "Hello, world!" 
- # AIMessage - assert response.chat_history[1].function_call == { - "name": "KnowledgeBaseTool", - "thoughts": "这是一个检索的需求,我需要在KnowledgeBaseTool知识库中检索出与输入的query相关的段落,并返回给用户。", - "arguments": '{"query": "Hello, world!"}', - } - - # FunctionMessag - assert response.chat_history[2].name == "KnowledgeBaseTool" - assert ( - response.chat_history[2].content - == '{"documents": [{"id": "495735246643269", "title": "城市管理执法办法.pdf", ' - '"document": "住房和城乡建设部规章城市管理执法办法"}, {"id": "495735246643270", ' - '"title": "城市管理执法办法.pdf", "document": "城市管理执法主管部门应当定期开展执法人员的培训和考核。"}]}' - ) + assert response.chat_history[0].content == "问题:Hello, world!,要求:请在第一步执行检索的操作,并且检索只允许调用一次" + # # AIMessage + # assert response.chat_history[1].function_call == { + # "name": "KnowledgeBaseTool", + # "thoughts": "这是一个检索的需求,我需要在KnowledgeBaseTool知识库中检索出与输入的query相关的段落,并返回给用户。", + # "arguments": '{"query": "Hello, world!"}', + # } + + # # FunctionMessag + # assert response.chat_history[2].name == "KnowledgeBaseTool" + # assert ( + # response.chat_history[2].content + # == '{"documents": [{"id": "495735246643269", "title": "城市管理执法办法.pdf", ' + # '"document": "住房和城乡建设部规章城市管理执法办法"}, {"id": "495735246643270", ' + # '"title": "城市管理执法办法.pdf", "document": "城市管理执法主管部门应当定期开展执法人员的培训和考核。"}]}' + # ) # AIMessage - assert response.chat_history[3].content == "Text response" + assert response.chat_history[1].content == "Text response" # Test retrieval failed From b91517662a100bf3f80daeee59f07bf4b5c3dcb6 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 17 Jan 2024 08:38:29 +0000 Subject: [PATCH 27/43] remove functional_agent_with_retrieval_example.py --- ...functional_agent_with_retrieval_example.py | 175 ------------------ 1 file changed, 175 deletions(-) delete mode 100644 erniebot-agent/examples/functional_agent_with_retrieval_example.py diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py b/erniebot-agent/examples/functional_agent_with_retrieval_example.py deleted file mode 100644 index 
dfcdd5c04..000000000 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ /dev/null @@ -1,175 +0,0 @@ -import argparse -import asyncio -from typing import Dict, List, Type - -from erniebot_agent.agents import ( - ContextAugmentedFunctionalAgent, - FunctionalAgentWithRetrieval, - FunctionalAgentWithRetrievalScoreTool, - FunctionalAgentWithRetrievalTool, -) -from erniebot_agent.chat_models import ERNIEBot -from erniebot_agent.memory import WholeMemory -from erniebot_agent.messages import AIMessage, HumanMessage, Message -from erniebot_agent.retrieval import BaizhongSearch -from erniebot_agent.retrieval.document import Document -from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool -from erniebot_agent.tools.base import RemoteToolkit, Tool -from erniebot_agent.tools.schema import ToolParameterView -from langchain.document_loaders import PyPDFDirectoryLoader -from langchain.text_splitter import SpacyTextSplitter -from pydantic import Field -from tqdm import tqdm - -import erniebot - -parser = argparse.ArgumentParser() -parser.add_argument("--base_url", type=str, help="The Aurora serving path.") -parser.add_argument("--data_path", default="construction_regulations", type=str, help="The data path.") -parser.add_argument( - "--access_token", default="ai_studio_access_token", type=str, help="The aistudio access token." 
-) -parser.add_argument("--api_type", default="qianfan", type=str, help="The aistudio access token.") -parser.add_argument("--api_key", default="", type=str, help="The API Key.") -parser.add_argument("--secret_key", default="", type=str, help="The secret key.") -parser.add_argument("--indexing", action="store_true", help="The indexing step.") -parser.add_argument("--project_id", default=-1, type=int, help="The API Key.") -parser.add_argument( - "--retrieval_type", - choices=["rag", "rag_tool", "rag_threshold", "context_aug"], - default="rag", - help="Retrieval type, default to rag.", -) -args = parser.parse_args() - - -class NotesToolInputView(ToolParameterView): - draft: str = Field(description="草稿文本") - - -class NotesToolOutputView(ToolParameterView): - draft_results: str = Field(description="草稿文本结果") - - -class NotesTool(Tool): - description: str = "笔记本,用于记录和保存信息的笔记本工具" - input_type: Type[ToolParameterView] = NotesToolInputView - ouptut_type: Type[ToolParameterView] = NotesToolOutputView - - async def __call__(self, draft: str) -> Dict[str, str]: - # TODO: save draft to database - return {"draft_results": "草稿在笔记本中保存成功"} - - @property - def examples(self) -> List[Message]: - return [ - HumanMessage("OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中"), - AIMessage( - "", - function_call={ - "name": self.tool_name, - "thoughts": f"用户想保存笔记,我可以使用{self.tool_name}工具来保存,其中`draft`字段的内容为:'搜索的草稿'。", - "arguments": '{"draft": "搜索的草稿"}', - }, - ), - ] - - -def offline_ann(data_path, baizhong_db): - loader = PyPDFDirectoryLoader(data_path) - documents = loader.load() - text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0) - docs = text_splitter.split_documents(documents) - list_data = [] - for item in tqdm(docs): - doc_title = item.metadata["source"].split("/")[-1] - doc_content = item.page_content - example = {"title": doc_title, "content_se": doc_content} - example = Document.from_dict(example) - list_data.append(example) - res = 
baizhong_db.add_documents(documents=list_data) - return res - - -if __name__ == "__main__": - erniebot.api_type = args.api_type - erniebot.access_token = args.access_token - baizhong_db = BaizhongSearch( - base_url=args.base_url, - project_name="construct_assistant2", - remark="construction assistant test dataset", - project_id=args.project_id if args.project_id != -1 else None, - ) - print(baizhong_db.project_id) - if args.indexing: - res = offline_ann(args.data_path, baizhong_db) - print(res) - - llm = ERNIEBot(model="ernie-bot", api_type="custom") - - retrieval_tool = BaizhongSearchTool(description="在知识库中检索相关的段落", db=baizhong_db, threshold=0.1) - - # agent = FunctionalAgentWithRetrievalTool( - # llm=llm, knowledge_base=baizhong_db, top_k=3, tools=[NotesTool(), retrieval_tool], memory=memory - # ) - - # queries = [ - # "请把飞桨这两个字添加到笔记本中", - # "OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中", - # "OpenAI管理层变更会带来哪些影响?", - # "量化交易", - # "今天天气怎么样?", - # "abcabc", - # ] - - queries = [ - # "量化交易", - # "城市景观照明中有过度照明的规定是什么?", - "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", - # "这几篇文档主要内容是什么?", - # "今天天气怎么样?", - # "abcabc", - ] - toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") - for query in queries: - memory = WholeMemory() - if args.retrieval_type == "rag": - agent = FunctionalAgentWithRetrieval( # type: ignore - llm=llm, - knowledge_base=baizhong_db, - top_k=3, - tools=toolkit.get_tools() + [NotesTool(), retrieval_tool], - memory=memory, - ) - elif args.retrieval_type == "rag_tool": - agent = FunctionalAgentWithRetrievalTool( # type: ignore - llm=llm, - knowledge_base=baizhong_db, - top_k=3, - tools=toolkit.get_tools() + [retrieval_tool], - memory=memory, - ) - elif args.retrieval_type == "rag_threshold": - agent = FunctionalAgentWithRetrievalScoreTool( # type: ignore - llm=llm, - knowledge_base=baizhong_db, - top_k=3, - threshold=0.1, - tools=[NotesTool(), retrieval_tool], - memory=memory, - ) - elif args.retrieval_type == "context_aug": - agent = 
ContextAugmentedFunctionalAgent( # type: ignore - llm=llm, - knowledge_base=baizhong_db, - top_k=3, - threshold=0.1, - tools=[NotesTool(), retrieval_tool], - memory=memory, - ) - try: - response = asyncio.run(agent.async_run(query)) - print(f"query: {query}") - print(f"agent response: {response}") - except Exception as e: - print(e) From d4a034d15729e540126be2b13f67952427e871eb Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 17 Jan 2024 08:42:17 +0000 Subject: [PATCH 28/43] Update function_agent_with_retrieval.py --- .../agents/function_agent_with_retrieval.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 4feb6ef9e..feef359c6 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -439,29 +439,8 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) ) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) - - num_steps_taken = 0 - while num_steps_taken < self.max_steps: - curr_step, new_messages = await self._step(chat_history) - chat_history.extend(new_messages) - if isinstance(curr_step, ToolStep): - steps_taken.append(curr_step) - - elif isinstance(curr_step, PluginStep): - steps_taken.append(curr_step) - # 预留 调用了Plugin之后不结束的接口 - - # 此处为调用了Plugin之后直接结束的Plugin - curr_step = DEFAULT_FINISH_STEP - - if isinstance(curr_step, EndStep): # plugin with action - response = self._create_finished_response(chat_history, steps_taken, curr_step=curr_step) - self.memory.add_message(chat_history[0]) - self.memory.add_message(chat_history[-1]) - return response - num_steps_taken += 1 - response = self._create_stopped_response(chat_history, steps_taken) - return response + rewrite_prompt = 
"背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" + return super()._run(rewrite_prompt, files) else: _logger.info( f"Irrelevant retrieval results. Fallbacking to FunctionAgent for the query: {prompt}" From 30d8081ff0827904864c32a7d2e12ac699f72c57 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 17 Jan 2024 08:42:40 +0000 Subject: [PATCH 29/43] Update ContextAugmentedFunctionalAgent --- .../src/erniebot_agent/agents/function_agent_with_retrieval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index feef359c6..e81570f9c 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -440,7 +440,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) rewrite_prompt = "背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" - return super()._run(rewrite_prompt, files) + return super()._run(rewrite_prompt, files) else: _logger.info( f"Irrelevant retrieval results. 
Fallbacking to FunctionAgent for the query: {prompt}" From d97ee9d1b11374ee65b39561507bdf35da84d979 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 17 Jan 2024 09:09:17 +0000 Subject: [PATCH 30/43] Update prompt agent --- .../erniebot_agent/agents/prompt_agent.py | 39 -------- .../tools/openai_search_tool.py | 89 ------------------- .../src/erniebot_agent/agents/prompt_agent.py | 41 +++++++++ 3 files changed, 41 insertions(+), 128 deletions(-) delete mode 100644 erniebot-agent/erniebot_agent/agents/prompt_agent.py delete mode 100644 erniebot-agent/erniebot_agent/tools/openai_search_tool.py create mode 100644 erniebot-agent/src/erniebot_agent/agents/prompt_agent.py diff --git a/erniebot-agent/erniebot_agent/agents/prompt_agent.py b/erniebot-agent/erniebot_agent/agents/prompt_agent.py deleted file mode 100644 index ad4f94951..000000000 --- a/erniebot-agent/erniebot_agent/agents/prompt_agent.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Any, List, Optional - -from erniebot_agent.agents import FunctionalAgent -from erniebot_agent.agents.schema import AgentAction, AgentFile -from erniebot_agent.file_io.base import File -from erniebot_agent.messages import HumanMessage, Message - - -class PromptAgent(FunctionalAgent): - def __init__(self, top_k: int = 2, threshold: float = 0.1, token_limit: int = 3000, **kwargs): - super().__init__(**kwargs) - self.top_k = top_k - self.threshold = threshold - self.token_limit = token_limit - # self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。") - - async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> Any: - actions_taken: List[AgentAction] = [] - files_involved: List[AgentFile] = [] - chat_history: List[Message] = [] - - next_step_input = HumanMessage(content=prompt) - curr_step_output = await self._async_step( - next_step_input, chat_history, actions_taken, files_involved - ) - return curr_step_output - - async def _async_step( - self, - step_input, - 
chat_history: List[Message], - actions: List[AgentAction], - files: List[AgentFile], - ) -> Optional[Any]: - maybe_action = await self._async_plan(step_input, chat_history) - if isinstance(maybe_action, AgentAction): - return True - else: - return False diff --git a/erniebot-agent/erniebot_agent/tools/openai_search_tool.py b/erniebot-agent/erniebot_agent/tools/openai_search_tool.py deleted file mode 100644 index 2beb06296..000000000 --- a/erniebot-agent/erniebot_agent/tools/openai_search_tool.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Type - -from erniebot_agent.messages import AIMessage, HumanMessage -from erniebot_agent.tools.schema import ToolParameterView -from pydantic import Field - -from .base import Tool - - -class OpenAISearchToolInputView(ToolParameterView): - query: str = Field(description="查询语句") - top_k: int = Field(description="返回结果数量") - - -class SearchResponseDocument(ToolParameterView): - title: str = Field(description="检索结果的标题") - document: str = Field(description="检索结果的内容") - - -class OpenAISearchToolOutputView(ToolParameterView): - documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落") - - -class OpenAISearchTool(Tool): - description: str = "在知识库中检索与用户输入query相关的段落" - input_type: Type[ToolParameterView] = OpenAISearchToolInputView - ouptut_type: Type[ToolParameterView] = OpenAISearchToolOutputView - - def __init__( - self, - name, - description, - db, - threshold: float = 0.0, - input_type=None, - output_type=None, - examples=None, - return_meta_data: bool = False, - ) -> None: - super().__init__() - self.name = name - self.db = db - self.description = description - self.return_meta_data = return_meta_data - self.few_shot_examples = [] - if input_type is not None: - self.input_type = input_type - if output_type is not None: - self.ouptut_type = output_type - if examples is not None: - self.few_shot_examples = examples - self.threshold = 
threshold - - async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None): - documents = self.db.similarity_search_with_relevance_scores(query, top_k) - docs = [] - for doc, score in documents: - if score > self.threshold: - new_doc = {"document": doc.page_content} - if self.return_meta_data: - new_doc["meta"] = doc.metadata - if "source" in doc.metadata: - new_doc["title"] = doc.metadata["source"] - - docs.append(new_doc) - - return {"documents": docs} - - @property - def examples( - self, - ) -> List[Any]: - few_shot_objects: List[Any] = [] - for item in self.few_shot_examples: - few_shot_objects.append(HumanMessage(item["user"])) - few_shot_objects.append( - AIMessage( - "", - function_call={ - "name": self.tool_name, - "thoughts": item["thoughts"], - "arguments": item["arguments"], - }, - ) - ) - - return few_shot_objects diff --git a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py b/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py new file mode 100644 index 000000000..594a8d7aa --- /dev/null +++ b/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py @@ -0,0 +1,41 @@ +from typing import Any, List, Optional, Sequence, Tuple + +from erniebot_agent.agents import FunctionalAgent +from erniebot_agent.agents.schema import AgentAction, AgentFile, File, AgentResponse, AgentStep +from erniebot_agent.memory.messages import HumanMessage, Message +from erniebot_agent.tools.base import BaseTool + +class PromptAgent(FunctionalAgent): + + async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: + actions_taken: List[AgentAction] = [] + + chat_history: List[Message] = [] + + next_step_input = HumanMessage(content=prompt) + curr_step_output = await self._step( + next_step_input, chat_history, actions_taken + ) + return curr_step_output + + async def _step( + self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None + ) -> Tuple[AgentStep, List[Message]]: + new_messages: 
List[Message] = [] + input_messages = self.memory.get_messages() + chat_history + if selected_tool is not None: + tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} + llm_resp = await self.run_llm( + messages=input_messages, + functions=[selected_tool.function_call_schema()], # only regist one tool + tool_choice=tool_choice, + ) + else: + llm_resp = await self.run_llm(messages=input_messages) + + output_message = llm_resp.message # AIMessage + new_messages.append(output_message) + if output_message.function_call is not None: + return True + else: + return False From 2c395367fff59e998bf4920f96a9be7b19598a3b Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 17 Jan 2024 09:29:55 +0000 Subject: [PATCH 31/43] Update prompt agent --- .../agents/function_agent_with_retrieval.py | 2 +- .../src/erniebot_agent/agents/prompt_agent.py | 41 ++++++--- .../tools/langchain_retrieval_tool.py | 90 +++++++++++++++++++ 3 files changed, 118 insertions(+), 15 deletions(-) create mode 100644 erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index e81570f9c..cb288ad8d 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -440,7 +440,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) rewrite_prompt = "背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" - return super()._run(rewrite_prompt, files) + return await super()._run(rewrite_prompt, files) else: _logger.info( f"Irrelevant retrieval results. 
Fallbacking to FunctionAgent for the query: {prompt}" diff --git a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py b/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py index 594a8d7aa..648f2b42d 100644 --- a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py @@ -1,26 +1,26 @@ -from typing import Any, List, Optional, Sequence, Tuple +import json +from typing import List, Optional, Sequence -from erniebot_agent.agents import FunctionalAgent -from erniebot_agent.agents.schema import AgentAction, AgentFile, File, AgentResponse, AgentStep +from erniebot_agent.agents.agent import Agent +from erniebot_agent.agents.schema import AgentResponse, AgentStep, File from erniebot_agent.memory.messages import HumanMessage, Message from erniebot_agent.tools.base import BaseTool -class PromptAgent(FunctionalAgent): +class PromptAgent(Agent): async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: - actions_taken: List[AgentAction] = [] - chat_history: List[Message] = [] - - next_step_input = HumanMessage(content=prompt) - curr_step_output = await self._step( - next_step_input, chat_history, actions_taken + steps_taken: List[AgentStep] = [] + run_input = await HumanMessage.create_with_files( + prompt, files or [], include_file_urls=self.file_needs_url ) - return curr_step_output + chat_history.append(run_input) + msg = await self._step(chat_history) + text = json.dumps({"msg": msg}, ensure_ascii=False) + response = self._create_stopped_response(chat_history, steps_taken, message=text) + return response - async def _step( - self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None - ) -> Tuple[AgentStep, List[Message]]: + async def _step(self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None) -> bool: new_messages: List[Message] = [] input_messages = self.memory.get_messages() + chat_history if selected_tool is not None: @@ 
-39,3 +39,16 @@ async def _step( return True else: return False + + def _create_stopped_response( + self, + chat_history: List[Message], + steps: List[AgentStep], + message: str, + ) -> AgentResponse: + return AgentResponse( + text=message, + chat_history=chat_history, + steps=steps, + status="STOPPED", + ) diff --git a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py new file mode 100644 index 000000000..9c65be22c --- /dev/null +++ b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Type + +from pydantic import Field + +from erniebot_agent.memory.messages import AIMessage, HumanMessage +from erniebot_agent.tools.schema import ToolParameterView + +from .base import Tool + + +class LangChainRetrievalToolInputView(ToolParameterView): + query: str = Field(description="查询语句") + top_k: int = Field(description="返回结果数量") + + +class SearchResponseDocument(ToolParameterView): + title: str = Field(description="检索结果的标题") + document: str = Field(description="检索结果的内容") + + +class LangChainRetrievalToolOutputView(ToolParameterView): + documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落") + + +class LangChainRetrievalTool(Tool): + description: str = "在知识库中检索与用户输入query相关的段落" + input_type: Type[ToolParameterView] = LangChainRetrievalToolInputView + ouptut_type: Type[ToolParameterView] = LangChainRetrievalToolOutputView + + def __init__( + self, + name, + description, + db, + threshold: float = 0.0, + input_type=None, + output_type=None, + examples=None, + return_meta_data: bool = False, + ) -> None: + super().__init__() + self.name = name + self.db = db + self.description = description + self.return_meta_data = return_meta_data + self.few_shot_examples = [] + if input_type is not None: + self.input_type = input_type + if output_type is not None: 
+ self.ouptut_type = output_type + if examples is not None: + self.few_shot_examples = examples + self.threshold = threshold + + async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None): + documents = self.db.similarity_search_with_relevance_scores(query, top_k) + docs = [] + for doc, score in documents: + if score > self.threshold: + new_doc = {"document": doc.page_content} + if self.return_meta_data: + new_doc["meta"] = doc.metadata + if "source" in doc.metadata: + new_doc["title"] = doc.metadata["source"] + + docs.append(new_doc) + + return {"documents": docs} + + @property + def examples( + self, + ) -> List[Any]: + few_shot_objects: List[Any] = [] + for item in self.few_shot_examples: + few_shot_objects.append(HumanMessage(item["user"])) + few_shot_objects.append( + AIMessage( + "", + function_call={ + "name": self.tool_name, + "thoughts": item["thoughts"], + "arguments": item["arguments"], + }, + ) + ) + + return few_shot_objects From 4534b1dbd7c760497645f795a993a132f6b759f2 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 18 Jan 2024 06:24:53 +0000 Subject: [PATCH 32/43] remove prompt agent --- .../examples/automatic_prompt_engineer.py | 125 ------------------ erniebot-agent/examples/text_summarization.py | 61 --------- erniebot-agent/examples/utils.py | 120 ----------------- .../src/erniebot_agent/agents/prompt_agent.py | 54 -------- .../src/erniebot/resources/chat_completion.py | 1 + 5 files changed, 1 insertion(+), 360 deletions(-) delete mode 100644 erniebot-agent/examples/automatic_prompt_engineer.py delete mode 100644 erniebot-agent/examples/text_summarization.py delete mode 100644 erniebot-agent/examples/utils.py delete mode 100644 erniebot-agent/src/erniebot_agent/agents/prompt_agent.py diff --git a/erniebot-agent/examples/automatic_prompt_engineer.py b/erniebot-agent/examples/automatic_prompt_engineer.py deleted file mode 100644 index 685c197c0..000000000 --- 
a/erniebot-agent/examples/automatic_prompt_engineer.py +++ /dev/null @@ -1,125 +0,0 @@ -import argparse -import asyncio - -from erniebot_agent.agents.prompt_agent import PromptAgent -from erniebot_agent.chat_models import ERNIEBot -from erniebot_agent.memory import WholeMemory -from erniebot_agent.tools.openai_search_tool import OpenAISearchTool -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.vectorstores import FAISS -from prettytable import PrettyTable -from utils import ( - create_description, - create_keywords, - create_questions, - erniebot_chat, - read_data, -) - -import erniebot - -# yapf: disable -parser = argparse.ArgumentParser() -parser.add_argument("--api_type", default=None, type=str, help="The API Key.") -parser.add_argument("--access_token", default=None, type=str, help="The secret key.") -parser.add_argument("--summarization_path", default='data/data.jsonl', type=str, help="The output path.") -parser.add_argument("--number_of_prompts", default=3, type=int, help="The number of tool descriptions.") -parser.add_argument("--num_questions", default=3, type=int, help="The number of few shot questions.") -parser.add_argument("--num_keywords", default=-1, type=int, help="The number of few shot questions.") -args = parser.parse_args() -# yapf: enable - - -def generate_candidate_prompts(description, number_of_prompts): - prompts = [] - for i in range(number_of_prompts): - messages = [create_description(description)] - results = erniebot_chat( - messages, model="ernie-bot-4", api_type=args.api_type, access_token=args.access_token - )["result"] - prompts.append(results) - return prompts - - -def generate_candidate_questions( - description, num_questions: int = -1, num_keywords: int = -1, temperature=1e-10 -): - if num_questions > 0: - messages = [create_questions(description, num_questions=num_questions)] - elif num_keywords > 0: - messages = [create_keywords(description, num_keywords=num_keywords)] - - results = erniebot_chat( - 
messages, api_type=args.api_type, access_token=args.access_token, temperature=temperature - )["result"] - return results - - -if __name__ == "__main__": - erniebot.api_type = args.api_type - erniebot.access_token = args.access_token - - embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") - faiss = FAISS.load_local("城市管理执法办法", embeddings) - - list_data = read_data(args.summarization_path) - doc = list_data[0] - tool_descriptions = generate_candidate_prompts(doc["abstract"], number_of_prompts=args.number_of_prompts) - tool_descriptions = list(set(tool_descriptions)) - print(tool_descriptions) - - questions = generate_candidate_questions(doc["abstract"], num_questions=args.num_questions).split("\n") - - if args.num_keywords > 0: - keywords = generate_candidate_questions(doc["abstract"], num_keywords=args.num_keywords).split("\n") - questions += keywords - - prompts = tool_descriptions - prompt_results = {prompt: {"correct": 0.0, "total": 0.0} for prompt in prompts} - - # Initialize the table - table = PrettyTable() - table_field_names = ["Prompt"] + [ - f"question {i+1}-{j+1}" for j, prompt in enumerate(questions) for i in range(questions.count(prompt)) - ] - table.field_names = table_field_names - - # Wrap the text in the "Prompt" column - table.max_width["Prompt"] = 100 - - llm = ERNIEBot(model="ernie-bot") - best_prompt = None - best_percentage = 0.0 - for i, tool_description in enumerate(tool_descriptions): - openai_city_management = OpenAISearchTool( - name="city_administrative_law_enforcement", - description=tool_description, - db=faiss, - threshold=0.1, - ) - row = [tool_description] - resps = [] - for query in questions: - agent = PromptAgent(memory=WholeMemory(), llm=llm, tools=[openai_city_management]) - response = asyncio.run(agent.async_run(query)) - resps.append(response) - if response is True: - prompt_results[tool_description]["correct"] += 1 - row.append("✅") - else: - row.append("❌") - prompt_results[tool_description]["total"] += 1 - 
table.add_row(row) - - print(f"生成的问题如下:{questions}") - print(table) - for i, prompt in enumerate(prompts): - correct = prompt_results[prompt]["correct"] - total = prompt_results[prompt]["total"] - percentage = (correct / total) * 100 - print(f"Prompt {i+1} got {percentage:.2f}% correct.") - if percentage > best_percentage: - best_percentage = percentage - best_prompt = tool_description - - print(f"The best prompt was '{best_prompt}' with a correctness of {best_percentage:.2f}%.") diff --git a/erniebot-agent/examples/text_summarization.py b/erniebot-agent/examples/text_summarization.py deleted file mode 100644 index 9f4024707..000000000 --- a/erniebot-agent/examples/text_summarization.py +++ /dev/null @@ -1,61 +0,0 @@ -import argparse -import os - -import jsonlines -from utils import create_abstract, erniebot_chat, read_data, split_text - -# yapf: disable -parser = argparse.ArgumentParser() -parser.add_argument("--api_type", default=None, type=str, help="The API Key.") -parser.add_argument("--access_token", default=None, type=str, help="The secret key.") -parser.add_argument("--data_path", default='data/json_data.jsonl', type=str, help="The data path.") -parser.add_argument("--output_path", default='data/finance_abstract', type=str, help="The output path.") -parser.add_argument('--chatbot_type', choices=['erniebot'], default="erniebot", - help="The chatbot model types") -args = parser.parse_args() -# yapf: enable - - -def summarize_text(text: str): - summaries = [] - - chunks = list(split_text(text, max_length=4096)) - print(f"Summarizing text with total chunks: {len(chunks)}") - for i, chunk in enumerate(chunks): - messages = [create_abstract(chunk)] - summary = erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token).rbody[ - "result" - ] - print(summary) - summaries.append(summary) - - combined_summary = "\n".join(summaries) - combined_summary = combined_summary[:7000] - messages = [create_abstract(combined_summary)] - - final_summary = 
erniebot_chat(messages, api_type=args.api_type, access_token=args.access_token).rbody[ - "result" - ] - print("Final summary length: ", len(final_summary)) - print(final_summary) - return final_summary - - -def generate_summary_jsonl(): - os.makedirs(args.output_path, exist_ok=True) - list_data = read_data(args.data_path) - for md_file in list_data: - markdown_text = md_file["content"] - summary = summarize_text(markdown_text) - md_file["abstract"] = summary - - output_json = f"{args.output_path}/data.jsonl" - with jsonlines.open(output_json, "w") as f: - for item in list_data: - f.write(item) - return output_json - - -if __name__ == "__main__": - # text summarization - generate_summary_jsonl() diff --git a/erniebot-agent/examples/utils.py b/erniebot-agent/examples/utils.py deleted file mode 100644 index 25a373c56..000000000 --- a/erniebot-agent/examples/utils.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Dict, Generator - -import jsonlines - -import erniebot - - -def read_data(json_path): - list_data = [] - with jsonlines.open(json_path, "r") as f: - for item in f: - list_data.append(item) - return list_data - - -def create_abstract(chunk: str) -> Dict[str, str]: - """Create a message for the chat completion - - Args: - chunk (str): The chunk of text to summarize - question (str): The question to answer - - Returns: - Dict[str, str]: The message to send to the chat completion - """ - return { - "role": "user", - "content": f"""{chunk},请用中文对上述文章进行总结,总结需要有概括性,不允许输出与文章内容无关的信息,字数控制在500字以内。""", - } - - -def create_questions(chunk: str, num_questions: int = 5): - return { - "role": "user", - "content": f"""{chunk},请根据上面的摘要,生成{num_questions}个问题,问题内容和形式要多样化,口语化,不允许重复,分条列举出来.""", - } - - -def create_keywords(chunk: str, num_keywords: int = 3): - return { - "role": "user", - "content": f"""{chunk},请根据上面的摘要,抽取{num_keywords}个关键字或者简短关键句,分条列举出来。 - 要求:只需要输出关键字或者简短的关键句,不需要输出其它的内容.""", - } - - -def create_description(chunk: str) -> Dict[str, str]: - """Create a message 
for the chat completion - - Args: - chunk (str): The chunk of text to summarize - question (str): The question to answer - - Returns: - Dict[str, str]: The message to send to the chat completion - """ - return { - "role": "user", - "content": f"""{chunk},请根据上面的摘要,生成一个简短的描述,不超过40字.""", - } - - -def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]: - """Split text into chunks of a maximum length - - Args: - text (str): The text to split - max_length (int, optional): The maximum length of each chunk. Defaults to 8192. - - Yields: - str: The next chunk of text - - Raises: - ValueError: If the text is longer than the maximum length - """ - paragraphs = text.split("\n") - current_length = 0 - current_chunk = [] - - for paragraph in paragraphs: - if current_length + len(paragraph) + 1 <= max_length: - current_chunk.append(paragraph) - current_length += len(paragraph) + 1 - else: - yield "\n".join(current_chunk) - current_chunk = [paragraph] - current_length = len(paragraph) + 1 - - if current_chunk: - yield "\n".join(current_chunk) - - -def erniebot_chat( - messages, model="ernie-bot", api_type="aistudio", access_token=None, functions=None, **kwargs -): - """ - Args: - messages: dict or list, 输入的消息(message) - model: str, 模型名称 - api_type: str, 接口类型,可选值包括 'aistudio' 和 'qianfan' - access_token: str, 访问令牌(access token) - functions: list, 函数列表 - kwargs: 其他参数 - - Returns: - dict or list, 返回聊天结果 - """ - _config = dict( - api_type=api_type, - access_token=access_token, - ) - if functions is None: - resp_stream = erniebot.ChatCompletion.create( - _config_=_config, model=model, messages=messages, **kwargs - ) - else: - resp_stream = erniebot.ChatCompletion.create( - _config_=_config, model=model, messages=messages, **kwargs, functions=functions - ) - return resp_stream diff --git a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py b/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py deleted file mode 100644 index 648f2b42d..000000000 --- 
a/erniebot-agent/src/erniebot_agent/agents/prompt_agent.py +++ /dev/null @@ -1,54 +0,0 @@ -import json -from typing import List, Optional, Sequence - -from erniebot_agent.agents.agent import Agent -from erniebot_agent.agents.schema import AgentResponse, AgentStep, File -from erniebot_agent.memory.messages import HumanMessage, Message -from erniebot_agent.tools.base import BaseTool - - -class PromptAgent(Agent): - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: - chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] - run_input = await HumanMessage.create_with_files( - prompt, files or [], include_file_urls=self.file_needs_url - ) - chat_history.append(run_input) - msg = await self._step(chat_history) - text = json.dumps({"msg": msg}, ensure_ascii=False) - response = self._create_stopped_response(chat_history, steps_taken, message=text) - return response - - async def _step(self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None) -> bool: - new_messages: List[Message] = [] - input_messages = self.memory.get_messages() + chat_history - if selected_tool is not None: - tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} - llm_resp = await self.run_llm( - messages=input_messages, - functions=[selected_tool.function_call_schema()], # only regist one tool - tool_choice=tool_choice, - ) - else: - llm_resp = await self.run_llm(messages=input_messages) - - output_message = llm_resp.message # AIMessage - new_messages.append(output_message) - if output_message.function_call is not None: - return True - else: - return False - - def _create_stopped_response( - self, - chat_history: List[Message], - steps: List[AgentStep], - message: str, - ) -> AgentResponse: - return AgentResponse( - text=message, - chat_history=chat_history, - steps=steps, - status="STOPPED", - ) diff --git a/erniebot/src/erniebot/resources/chat_completion.py 
b/erniebot/src/erniebot/resources/chat_completion.py index a4ef00c7f..e626b690c 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -515,6 +515,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: # stream stream = kwargs.get("stream", False) + return RequestWithStream( method="POST", path=path, From ce74d548d10e2b9684bc2ad9500e691b508a0b0f Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 18 Jan 2024 06:27:26 +0000 Subject: [PATCH 33/43] delete faiss_util.py --- erniebot-agent/examples/faiss_util.py | 98 --------------------------- 1 file changed, 98 deletions(-) delete mode 100644 erniebot-agent/examples/faiss_util.py diff --git a/erniebot-agent/examples/faiss_util.py b/erniebot-agent/examples/faiss_util.py deleted file mode 100644 index 28f0e40d2..000000000 --- a/erniebot-agent/examples/faiss_util.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import os - -from langchain.docstore.document import Document -from langchain.document_loaders import UnstructuredFileLoader -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.text_splitter import SpacyTextSplitter -from langchain.vectorstores import FAISS -from utils import read_data - -embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") - - -def get_args(): - # yapf: disable - parser = argparse.ArgumentParser() - parser.add_argument('--faiss_name', default="faiss_index", help="The faiss index") - parser.add_argument('--summary_name', default="faiss_index", help="The summary text index") - parser.add_argument('--fulltext_name', default="faiss_index", help="The full text index") - parser.add_argument('--file_path', default="data/output.jsonl", help="The data output path") - parser.add_argument('--indexing_type', choices=['summary_fulltext', 'common'], - default="common", help="The indexing types") - args = parser.parse_args() - # yapf: enable - return args - - -class SemanticSearch: - def 
__init__(self, faiss_name, file_path=None) -> None: - self.faiss_name = faiss_name - self.file_path = file_path - self.vector_db = self.init_db() - - def init_db( - self, - ): - if os.path.exists(self.faiss_name): - faiss = FAISS.load_local(self.faiss_name, embeddings) - else: - loader = UnstructuredFileLoader(self.file_path) - documents = loader.load() - text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0) - docs = text_splitter.split_documents(documents) - faiss = FAISS.from_documents(docs, embeddings) - faiss.save_local(self.faiss_name) - return faiss - - def search(self, query, top_k=4): - return self.vector_db.similarity_search(query, k=top_k) - - -class RecursiveDocuments: - def __init__(self, summary_name, fulltext_name, file_path=None) -> None: - self.summary_name = summary_name - self.fulltext_name = fulltext_name - self.file_path = file_path - self.vector_db = self.init_db() - - def init_db( - self, - ): - if os.path.exists(self.summary_name) and os.path.exists(self.fulltext_name): - summary_faiss = FAISS.load_local(self.summary_name, embeddings) - fulltext_faiss = FAISS.load_local(self.fulltext_name, embeddings) - else: - list_data = read_data(self.file_path) - doc_summary = [] - doc_fulltext = [] - text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0) - for item in list_data: - full_texts = Document(page_content=item["content"]) - - abstract = Document(page_content=item["abstract"]) - docs = text_splitter.split_documents([full_texts]) - doc_fulltext.extend(docs) - - doc_summary.append(abstract) - - summary_faiss = FAISS.from_documents(doc_summary, embeddings) - summary_faiss.save_local(self.summary_name) - - fulltext_faiss = FAISS.from_documents(doc_fulltext, embeddings) - fulltext_faiss.save_local(self.fulltext_name) - return summary_faiss - - def search(self, query, top_k=4): - return self.vector_db.similarity_search(query, k=top_k) - - -if __name__ == "__main__": - 
query = "GPT-3是怎么训练得到的?" - args = get_args() - if args.indexing_type == "common": - faiss_search = SemanticSearch(args.faiss_name, args.file_path) - docs = faiss_search.search(query) - print(docs) - else: - recursive_search = RecursiveDocuments(args.summary_name, args.fulltext_name, args.file_path) From f8b0450952b25a66114f42ec998f7e4c42a0593b Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 18 Jan 2024 06:29:54 +0000 Subject: [PATCH 34/43] Update --- erniebot/src/erniebot/resources/chat_completion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index e626b690c..95263e1f1 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -515,7 +515,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: # stream stream = kwargs.get("stream", False) - + return RequestWithStream( method="POST", path=path, @@ -573,4 +573,4 @@ def to_message(self) -> Dict[str, Any]: message["function_call"] = self.function_call else: message["content"] = self.result - return message + return message \ No newline at end of file From c104795aaa9605a513c8a0ad3b74bd82edebd33c Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 18 Jan 2024 06:32:12 +0000 Subject: [PATCH 35/43] delete knowledge_tools_example --- .../examples/knowledge_tools_example.py | 205 ------------------ 1 file changed, 205 deletions(-) delete mode 100644 erniebot-agent/examples/knowledge_tools_example.py diff --git a/erniebot-agent/examples/knowledge_tools_example.py b/erniebot-agent/examples/knowledge_tools_example.py deleted file mode 100644 index d6e90f7f3..000000000 --- a/erniebot-agent/examples/knowledge_tools_example.py +++ /dev/null @@ -1,205 +0,0 @@ -import argparse -import asyncio -from typing import Dict, List, Type - -from erniebot_agent.agents import FunctionalAgentWithQueryPlanning -from 
erniebot_agent.chat_models import ERNIEBot -from erniebot_agent.memory import WholeMemory -from erniebot_agent.messages import AIMessage, HumanMessage, Message -from erniebot_agent.retrieval import BaizhongSearch -from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool -from erniebot_agent.tools.base import RemoteToolkit, Tool -from erniebot_agent.tools.openai_search_tool import OpenAISearchTool -from erniebot_agent.tools.schema import ToolParameterView -from langchain.docstore.document import Document -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.vectorstores import FAISS -from pydantic import Field - -import erniebot - -parser = argparse.ArgumentParser() -parser.add_argument("--base_url", type=str, help="The Aurora serving path.") -parser.add_argument("--data_path", default="construction_regulations", type=str, help="The data path.") -parser.add_argument( - "--access_token", default="ai_studio_access_token", type=str, help="The aistudio access token." 
-) -parser.add_argument("--api_type", default="qianfan", type=str, help="The aistudio access token.") -parser.add_argument("--api_key", default="", type=str, help="The API Key.") -parser.add_argument("--secret_key", default="", type=str, help="The secret key.") -parser.add_argument("--indexing", action="store_true", help="The indexing step.") -parser.add_argument("--project_id", default=-1, type=int, help="The API Key.") -parser.add_argument( - "--retrieval_type", - choices=["summary_fulltext_tools", "knowledge_tools"], - default="knowledge_tools", - help="Retrieval type, default to rag.", -) -parser.add_argument( - "--search_engine", - choices=["baizhong", "openai"], - default="baizhong", - help="search_engine.", -) -args = parser.parse_args() - - -class NotesToolInputView(ToolParameterView): - draft: str = Field(description="草稿文本") - - -class NotesToolOutputView(ToolParameterView): - draft_results: str = Field(description="草稿文本结果") - - -class NotesTool(Tool): - description: str = "笔记本,用于记录和保存信息的笔记本工具" - input_type: Type[ToolParameterView] = NotesToolInputView - ouptut_type: Type[ToolParameterView] = NotesToolOutputView - - async def __call__(self, draft: str) -> Dict[str, str]: - # TODO: save draft to database - return {"draft_results": "草稿在笔记本中保存成功"} - - @property - def examples(self) -> List[Message]: - return [ - HumanMessage("OpenAI管理层变更会带来哪些影响?并请把搜索的内容添加到笔记本中"), - AIMessage( - "", - function_call={ - "name": self.tool_name, - "thoughts": f"用户想保存笔记,我可以使用{self.tool_name}工具来保存,其中`draft`字段的内容为:'搜索的草稿'。", - "arguments": '{"draft": "搜索的草稿"}', - }, - ), - ] - - -if __name__ == "__main__": - erniebot.api_type = args.api_type - erniebot.access_token = args.access_token - - llm = ERNIEBot(model="ernie-bot", api_type="custom") - if args.search_engine == "baizhong": - baizhong_db = BaizhongSearch( - base_url=args.base_url, - project_name="construct_assistant2", - remark="construction assistant test dataset", - project_id=args.project_id if args.project_id != -1 else 
None, - ) - print(baizhong_db.project_id) - # 建筑规范数据集 - city_management = BaizhongSearchTool( - name="city_administrative_law_enforcement", - description="提供城市管理执法办法相关的信息", - db=baizhong_db, - threshold=0.1, - ) - city_design = BaizhongSearchTool( - name="city_design_management", description="提供城市设计管理办法的信息", db=baizhong_db, threshold=0.1 - ) - city_lighting = BaizhongSearchTool( - name="city_lighting", description="提供关于城市照明管理规定的信息", db=baizhong_db, threshold=0.1 - ) - - tool_retriever = BaizhongSearchTool( - name="tool_retriever", description="用于检索与query相关的tools列表", db=baizhong_db, threshold=0.1 - ) - elif args.search_engine == "openai" and args.retrieval_type == "knowledge_tools": - embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") - faiss = FAISS.load_local("城市管理执法办法", embeddings) - openai_city_management = OpenAISearchTool( - name="city_administrative_law_enforcement", - description="提供城市管理执法办法相关的信息", - db=faiss, - threshold=0.1, - ) - faiss = FAISS.load_local("城市设计管理办法", embeddings) - openai_city_design = OpenAISearchTool( - name="city_design_management", description="提供城市设计管理办法的信息", db=faiss, threshold=0.1 - ) - faiss = FAISS.load_local("城市照明管理规定", embeddings) - openai_city_lighting = OpenAISearchTool( - name="city_lighting", description="提供关于城市照明管理规定的信息", db=faiss, threshold=0.1 - ) - # TODO(wugaoshewng) 加入APE后,变成knowledge_base_toolkit - # faiss = FAISS.load_local("tool_retriever", embeddings) - tool_map = { - "city_administrative_law_enforcement": openai_city_management, - "city_design_management": openai_city_design, - "city_lighting": openai_city_lighting, - } - docs = [] - for tool in tool_map.values(): - doc = Document(page_content=tool.description, metadata={"tool_name": tool.name}) - docs.append(doc) - - faiss_tool = FAISS.from_documents(docs, embeddings) - tool_retriever = OpenAISearchTool( # type: ignore - name="tool_retriever", - description="用于检索与query相关的tools列表", - db=faiss_tool, - threshold=0.1, - return_meta_data=True, - ) - 
fulltext_faiss = FAISS.load_local("fulltext", embeddings) - vector_tool = OpenAISearchTool( - name="fulltext_search", - description="使用这个工具检索特定的上下文,以回答有关建筑规范具体的问题", - db=fulltext_faiss, - threshold=0.1, - ) - elif args.search_engine == "openai" and args.retrieval_type == "summary_fulltext_tools": - embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") - summary_faiss = FAISS.load_local("summary", embeddings) - summary_tool = OpenAISearchTool( - name="text_summary_search", description="使用这个工具总结与建筑规范相关的问题", db=summary_faiss, threshold=0.1 - ) - fulltext_faiss = FAISS.load_local("fulltext", embeddings) - vector_tool = OpenAISearchTool( - name="fulltext_search", - description="使用这个工具检索特定的上下文,以回答有关建筑规范具体的问题", - db=fulltext_faiss, - threshold=0.1, - ) - - queries = [ - "量化交易", - "OpenAI管理层变更会带来哪些影响?", - "城市景观照明中有过度照明的规定是什么?", - "城市景观照明中有过度照明的规定是什么?并把搜索的内容添加到笔记本中", - "请比较一下城市设计管理和照明管理规定的区别?", - "这几篇文档主要内容是什么?", - "今天天气怎么样?", - "abcabc", - ] - toolkit = RemoteToolkit.from_openapi_file("../tests/fixtures/openapi.yaml") - for query in queries: - memory = WholeMemory() - if args.retrieval_type == "summary_fulltext_tools": - agent = FunctionalAgentWithQueryPlanning( # type: ignore - llm=llm, - top_k=3, - tools=[summary_tool, vector_tool], - memory=memory, - ) - elif args.retrieval_type == "knowledge_tools": - tool_results = asyncio.run(tool_retriever(query))["documents"] - selected_tools = [tool_map[item["meta"]["tool_name"]] for item in tool_results] - agent = FunctionalAgentWithQueryPlanning( # type: ignore - llm=llm, - top_k=3, - knowledge_base=vector_tool, - # tools=toolkit.get_tools() + [city_management, city_design, city_lighting], - # tools = [NotesTool(),city_management, city_design, city_lighting], - tools=[NotesTool()] + selected_tools, - memory=memory, - ) - - try: - response = asyncio.run(agent.async_run(query)) - print(f"query: {query}") - print(f"agent response: {response}") - except Exception as e: - print(e) From 
c85c7c4cd09776c389d9bee296abf3e52d13c243 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 18 Jan 2024 06:39:50 +0000 Subject: [PATCH 36/43] Update format --- erniebot-agent/src/erniebot_agent/agents/__init__.py | 4 ++-- .../agents/function_agent_with_retrieval.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/__init__.py b/erniebot-agent/src/erniebot_agent/agents/__init__.py index 20909f669..85a7aadfc 100644 --- a/erniebot-agent/src/erniebot_agent/agents/__init__.py +++ b/erniebot-agent/src/erniebot_agent/agents/__init__.py @@ -15,9 +15,9 @@ from erniebot_agent.agents.agent import Agent from erniebot_agent.agents.function_agent import FunctionAgent from erniebot_agent.agents.function_agent_with_retrieval import ( - ContextAugmentedFunctionalAgent, + ContextAugmentedFunctionAgent, FunctionAgentWithRetrieval, FunctionAgentWithRetrievalScoreTool, FunctionAgentWithRetrievalTool, - FunctionalAgentWithQueryPlanning, + FunctionAgentWithQueryPlanning, ) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index cb288ad8d..925bcebf8 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -363,7 +363,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age return response else: _logger.info( - f"Irrelevant retrieval results. Fallbacking to FunctionalAgent for the query: {prompt}" + f"Irrelevant retrieval results. 
Fallbacking to FunctionAgent for the query: {prompt}" ) return await super()._run(prompt, files) @@ -378,7 +378,7 @@ async def _maybe_retrieval( return results -class ContextAugmentedFunctionalAgent(FunctionAgent): +class ContextAugmentedFunctionAgent(FunctionAgent): def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) self.knowledge_base = knowledge_base @@ -471,7 +471,7 @@ async def _maybe_retrieval( 请根据以上检索结果回答检索语句的问题""" -class FunctionalAgentWithQueryPlanning(FunctionAgent): +class FunctionAgentWithQueryPlanning(FunctionAgent): def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) self.top_k = top_k @@ -557,10 +557,10 @@ def _parse_results(self, results): left_index = results.find("{") right_index = results.rfind("}") if left_index == -1 or right_index == -1: - # if invalid json, use Functional Agent + # if invalid json, use Function Agent return {"is_relevant": False} try: return json.loads(results[left_index : right_index + 1]) except Exception: - # if invalid json, use Functional Agent + # if invalid json, use Function Agent return {"is_relevant": False} From 67400561ba0c0d4abf6ac61989b02d232d57953f Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 03:25:29 +0000 Subject: [PATCH 37/43] remove FunctionAgentWithQueryPlanning --- .../agents/function_agent_with_retrieval.py | 166 ++++-------------- .../tools/langchain_retrieval_tool.py | 6 +- 2 files changed, 36 insertions(+), 136 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 925bcebf8..c95d8202d 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -378,6 +378,10 @@ async def _maybe_retrieval( return results 
+REWRITE_PROMPT = """请对下面的问题进行子问题提取,用于检索相关的信息帮助回答原问题,要求: +1.严格按照【JSON格式】的形式输出:{'sub query1':'具体子问题1','sub_query2':'具体子问题2'} +原问题:{{query}} 提取的子问题:""" + class ContextAugmentedFunctionAgent(FunctionAgent): def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) @@ -386,10 +390,12 @@ def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: fl self.threshold = threshold self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) self.search_tool = KnowledgeBaseTool() + self.query_rewrite = PromptTemplate(REWRITE_PROMPT, input_variables=["query"]) async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: - results = await self._maybe_retrieval(prompt) - if len(results["documents"]) > 0: + # Rewrite queries for retrieval + results = await self._query_rewrite(prompt) + if len(results) > 0: # RAG chat_history: List[Message] = [] steps_taken: List[AgentStep] = [] @@ -398,38 +404,23 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age await self._callback_manager.on_tool_start( agent=self, tool=self.search_tool, input_args=tool_args ) + # Generate Answer step_input = HumanMessage( - content=self.rag_prompt.format(query=prompt, documents=results["documents"]) + content=self.rag_prompt.format(query=prompt, documents=results) ) fake_chat_history: List[Message] = [] fake_chat_history.append(step_input) llm_resp = await self.run_llm(messages=fake_chat_history) - - # Get RAG results output_message = llm_resp.message - outputs = [] - for item in results["documents"]: - outputs.append( - { - "id": item["id"], - "title": item["title"], - "document": item["content_se"], - } - ) - - # 会有无限循环调用工具的问题 + # Context Augmented query + rewrite_prompt = f"背景信息为:{output_message.content} \n 给定背景信息,而不是先验知识,选择相应的工具回答或者根据背景信息直接回答问题:{prompt}" next_step_input = HumanMessage( - content=f"背景信息为:{output_message.content} \n 
要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" + content=rewrite_prompt ) chat_history.append(next_step_input) - # Knowledge Retrieval Tool - # action = ToolAction(tool_name=self.search_tool.tool_name, tool_args=tool_args) # return response - tool_ret_json = json.dumps({"documents": outputs}, ensure_ascii=False) - # next_step_input = FunctionMessage(name=action.tool_name, content=tool_ret_json) - - tool_resp = ToolResponse(json=tool_ret_json, input_files=[], output_files=[]) + tool_resp = ToolResponse(json=json.dumps(results, ensure_ascii=False), input_files=[], output_files=[]) steps_taken.append( ToolStep( info=ToolInfo(tool_name=self.search_tool.tool_name, tool_args=tool_args), @@ -439,7 +430,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) ) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) - rewrite_prompt = "背景信息为:{output_message.content} \n 要求:选择相应的工具回答或者根据背景信息直接回答:{prompt}" + return await super()._run(rewrite_prompt, files) else: _logger.info( @@ -447,120 +438,29 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) return await super()._run(prompt, files) + async def _query_rewrite(self, prompt): + # Rewrite queries for retrieval + step_input = HumanMessage(content=self.query_rewrite.format(query=prompt)) + fake_chat_history: List[Message] = [] + fake_chat_history.append(step_input) + llm_resp = await self.run_llm(messages=fake_chat_history, functions=None) + output_message = llm_resp.message + queries = self._parse_results(output_message.content) + results = [] + for sub_query in list(queries.values()): + sub_results = await self._maybe_retrieval(sub_query) + results.extend(sub_results) + return results + async def _maybe_retrieval( self, step_input, ): - documents = self.knowledge_base.search(step_input, top_k=self.top_k, filters=None) - documents = [item for item in documents if item["score"] > self.threshold] - results = {} - results["documents"] = 
documents - return results - - -QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,每个子问题必须足够简单,要求: -1.严格按照【JSON格式】的形式输出:{'子问题1':'具体子问题1','子问题2':'具体子问题2'} -问题:{{prompt}} 子问题:""" - - -OPENAI_RAG_PROMPT = """检索结果: -{% for doc in documents %} - 第{{loop.index}}个段落: {{doc['document']}} -{% endfor %} -检索语句: {{query}} -请根据以上检索结果回答检索语句的问题""" - - -class FunctionAgentWithQueryPlanning(FunctionAgent): - def __init__(self, knowledge_base, top_k: int = 2, threshold: float = 0.1, **kwargs): - super().__init__(**kwargs) - self.top_k = top_k - self.threshold = threshold - self.query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"]) - self.knowledge_base = knowledge_base - self.rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"]) - - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: - chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] - - # 会有无限循环调用工具的问题 - # next_step_input = HumanMessage( - # content=f"请选择合适的工具来回答:{prompt},如果需要的话,可以对把问题分解成子问题,然后每个子问题选择合适的工具回答。" - # ) - run_input = await HumanMessage.create_with_files( - prompt, files or [], include_file_urls=self.file_needs_url - ) - chat_history.append(run_input) - num_steps_taken = 0 - while num_steps_taken < self.max_steps: - curr_step, new_messages = await self._step(chat_history) - chat_history.extend(new_messages) - if isinstance(curr_step, ToolStep): - steps_taken.append(curr_step) - - elif isinstance(curr_step, PluginStep): - steps_taken.append(curr_step) - # 预留 调用了Plugin之后不结束的接口 - - # 此处为调用了Plugin之后直接结束的Plugin - curr_step = DEFAULT_FINISH_STEP - - if isinstance(curr_step, EndStep): - response = self._create_finished_response(chat_history, steps_taken, curr_step) - self.memory.add_message(chat_history[0]) - self.memory.add_message(chat_history[-1]) - return response - num_steps_taken += 1 - # TODO(wugaosheng): Add manual planning and execute - return await self.plan_and_execute(prompt, steps_taken, curr_step) - - 
async def plan_and_execute(self, prompt, steps_taken: List[AgentStep], curr_step: AgentStep): - step_input = HumanMessage(content=self.query_transform.format(prompt=prompt)) - fake_chat_history: List[Message] = [step_input] - llm_resp = await self.run_llm( - messages=fake_chat_history, - ) - output_message = llm_resp.message - - json_results = self._parse_results(output_message.content) - sub_queries = json_results.values() - retrieval_results = [] - duplicates = set() - for query in sub_queries: - documents = await self.knowledge_base(query, top_k=self.top_k, filters=None) - docs = [item for item in documents["documents"]] - for doc in docs: - if doc["document"] not in duplicates: - duplicates.add(doc["document"]) - retrieval_results.append(doc) - step_input = HumanMessage( - content=self.rag_prompt.format(query=prompt, documents=retrieval_results[:3]) - ) - chat_history: List[Message] = [step_input] - llm_resp = await self.run_llm(messages=chat_history) - - output_message = llm_resp.message - chat_history.append(output_message) - last_message = chat_history[-1] - response = AgentResponse( - text=last_message.content, - chat_history=chat_history, - steps=steps_taken, - status="FINISHED", - ) - self.memory.add_message(chat_history[0]) - self.memory.add_message(chat_history[-1]) - return response + documents = await self.knowledge_base(step_input, top_k=self.top_k, filters=None) + documents = [item for item in documents['documents'] if item["score"] > self.threshold] + return documents def _parse_results(self, results): left_index = results.find("{") right_index = results.rfind("}") - if left_index == -1 or right_index == -1: - # if invalid json, use Function Agent - return {"is_relevant": False} - try: - return json.loads(results[left_index : right_index + 1]) - except Exception: - # if invalid json, use Function Agent - return {"is_relevant": False} + return json.loads(results[left_index : right_index + 1]) \ No newline at end of file diff --git 
a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py index 9c65be22c..b71967b47 100644 --- a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py @@ -31,9 +31,9 @@ class LangChainRetrievalTool(Tool): def __init__( self, - name, - description, db, + name: Optional[str] = None, + description: Optional[str] = None, threshold: float = 0.0, input_type=None, output_type=None, @@ -59,7 +59,7 @@ async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, docs = [] for doc, score in documents: if score > self.threshold: - new_doc = {"document": doc.page_content} + new_doc = {"content": doc.page_content,'score':score} if self.return_meta_data: new_doc["meta"] = doc.metadata if "source" in doc.metadata: From ccea25bdb9847a29bb9b088ad998858db63a477f Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 03:40:09 +0000 Subject: [PATCH 38/43] Update retrieval agents --- .../src/erniebot_agent/agents/__init__.py | 1 - .../agents/function_agent_with_retrieval.py | 29 +++++++++---------- .../tools/langchain_retrieval_tool.py | 2 +- ...unction_agent_with_retrieval_score_tool.py | 16 ---------- 4 files changed, 15 insertions(+), 33 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/__init__.py b/erniebot-agent/src/erniebot_agent/agents/__init__.py index 85a7aadfc..497aa4111 100644 --- a/erniebot-agent/src/erniebot_agent/agents/__init__.py +++ b/erniebot-agent/src/erniebot_agent/agents/__init__.py @@ -19,5 +19,4 @@ FunctionAgentWithRetrieval, FunctionAgentWithRetrievalScoreTool, FunctionAgentWithRetrievalTool, - FunctionAgentWithQueryPlanning, ) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index c95d8202d..db7b7469d 100644 --- 
a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -382,6 +382,7 @@ async def _maybe_retrieval( 1.严格按照【JSON格式】的形式输出:{'sub query1':'具体子问题1','sub_query2':'具体子问题2'} 原问题:{{query}} 提取的子问题:""" + class ContextAugmentedFunctionAgent(FunctionAgent): def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: float = 0.1, **kwargs): super().__init__(**kwargs) @@ -405,22 +406,21 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age agent=self, tool=self.search_tool, input_args=tool_args ) # Generate Answer - step_input = HumanMessage( - content=self.rag_prompt.format(query=prompt, documents=results) - ) - fake_chat_history: List[Message] = [] - fake_chat_history.append(step_input) + step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=results)) + fake_chat_history: List[Message] = [step_input] llm_resp = await self.run_llm(messages=fake_chat_history) output_message = llm_resp.message # Context Augmented query - rewrite_prompt = f"背景信息为:{output_message.content} \n 给定背景信息,而不是先验知识,选择相应的工具回答或者根据背景信息直接回答问题:{prompt}" - next_step_input = HumanMessage( - content=rewrite_prompt + rewrite_prompt = ( + f"背景信息为:{output_message.content} \n 给定背景信息,而不是先验知识,选择相应的工具回答或者根据背景信息直接回答问题:{prompt}" ) + next_step_input = HumanMessage(content=rewrite_prompt) chat_history.append(next_step_input) - # return response - tool_resp = ToolResponse(json=json.dumps(results, ensure_ascii=False), input_files=[], output_files=[]) + # Return response + tool_resp = ToolResponse( + json=json.dumps(results, ensure_ascii=False), input_files=[], output_files=[] + ) steps_taken.append( ToolStep( info=ToolInfo(tool_name=self.search_tool.tool_name, tool_args=tool_args), @@ -430,7 +430,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) ) await self._callback_manager.on_tool_end(agent=self, 
tool=self.search_tool, response=tool_resp) - + # Execute next step return await super()._run(rewrite_prompt, files) else: _logger.info( @@ -441,8 +441,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age async def _query_rewrite(self, prompt): # Rewrite queries for retrieval step_input = HumanMessage(content=self.query_rewrite.format(query=prompt)) - fake_chat_history: List[Message] = [] - fake_chat_history.append(step_input) + fake_chat_history: List[Message] = [step_input] llm_resp = await self.run_llm(messages=fake_chat_history, functions=None) output_message = llm_resp.message queries = self._parse_results(output_message.content) @@ -457,10 +456,10 @@ async def _maybe_retrieval( step_input, ): documents = await self.knowledge_base(step_input, top_k=self.top_k, filters=None) - documents = [item for item in documents['documents'] if item["score"] > self.threshold] + documents = [item for item in documents["documents"] if item["score"] > self.threshold] return documents def _parse_results(self, results): left_index = results.find("{") right_index = results.rfind("}") - return json.loads(results[left_index : right_index + 1]) \ No newline at end of file + return json.loads(results[left_index : right_index + 1]) diff --git a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py index b71967b47..e038ece9d 100644 --- a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py @@ -59,7 +59,7 @@ async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, docs = [] for doc, score in documents: if score > self.threshold: - new_doc = {"content": doc.page_content,'score':score} + new_doc = {"content": doc.page_content, "score": score} if self.return_meta_data: new_doc["meta"] = doc.metadata if "source" in doc.metadata: diff --git 
a/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py b/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py index a70da5f37..e060624d0 100644 --- a/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py +++ b/erniebot-agent/tests/unit_tests/agents/test_function_agent_with_retrieval_score_tool.py @@ -121,22 +121,6 @@ async def test_functional_agent_with_retrieval_retrieval_score_tool_run_retrieva assert response.text == "Text response" # HumanMessage assert response.chat_history[0].content == "问题:Hello, world!,要求:请在第一步执行检索的操作,并且检索只允许调用一次" - # # AIMessage - # assert response.chat_history[1].function_call == { - # "name": "KnowledgeBaseTool", - # "thoughts": "这是一个检索的需求,我需要在KnowledgeBaseTool知识库中检索出与输入的query相关的段落,并返回给用户。", - # "arguments": '{"query": "Hello, world!"}', - # } - - # # FunctionMessag - # assert response.chat_history[2].name == "KnowledgeBaseTool" - # assert ( - # response.chat_history[2].content - # == '{"documents": [{"id": "495735246643269", "title": "城市管理执法办法.pdf", ' - # '"document": "住房和城乡建设部规章城市管理执法办法"}, {"id": "495735246643270", ' - # '"title": "城市管理执法办法.pdf", "document": "城市管理执法主管部门应当定期开展执法人员的培训和考核。"}]}' - # ) - # AIMessage assert response.chat_history[1].content == "Text response" From 4551c7cf590ceaf6128eb359502723718f62085f Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 04:23:43 +0000 Subject: [PATCH 39/43] Add unitests --- .../erniebot_agent/agents/function_agent.py | 5 +- .../agents/function_agent_with_retrieval.py | 8 +- .../test_context_augmented_function_agent.py | 181 ++++++++++++++++++ 3 files changed, 190 insertions(+), 4 deletions(-) create mode 100644 erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index 5e2d7746c..49f0d7cd6 100644 --- 
a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -124,9 +124,10 @@ def __init__( else: self._first_tools = [] - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: + async def _run( + self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: Optional[AgentStep] = [] + ) -> AgentResponse: chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] run_input = await HumanMessage.create_with_files( prompt, files or [], include_file_urls=self.file_needs_url diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index db7b7469d..90b869225 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -431,7 +431,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age ) await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp) # Execute next step - return await super()._run(rewrite_prompt, files) + return await super()._run(rewrite_prompt, files, steps_taken) else: _logger.info( f"Irrelevant retrieval results. 
Fallbacking to FunctionAgent for the query: {prompt}" @@ -445,10 +445,14 @@ async def _query_rewrite(self, prompt): llm_resp = await self.run_llm(messages=fake_chat_history, functions=None) output_message = llm_resp.message queries = self._parse_results(output_message.content) + duplicates = set() results = [] for sub_query in list(queries.values()): sub_results = await self._maybe_retrieval(sub_query) - results.extend(sub_results) + for doc in sub_results: + if doc["content"] not in duplicates: + duplicates.add(doc["content"]) + results.append(doc) return results async def _maybe_retrieval( diff --git a/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py b/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py new file mode 100644 index 000000000..3a2bd35c9 --- /dev/null +++ b/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py @@ -0,0 +1,181 @@ +import json +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from erniebot_agent.agents import ContextAugmentedFunctionAgent +from erniebot_agent.memory import AIMessage, HumanMessage +from erniebot_agent.retrieval import BaizhongSearch +from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool +from tests.unit_tests.agents.common_util import EXAMPLE_RESPONSE +from tests.unit_tests.testing_utils.components import CountingCallbackHandler +from tests.unit_tests.testing_utils.mocks.mock_chat_models import ( + FakeERNIEBotWithPresetResponses, +) +from tests.unit_tests.testing_utils.mocks.mock_memory import FakeMemory +from tests.unit_tests.testing_utils.mocks.mock_tool import FakeTool + + +@pytest.fixture(scope="module") +def identity_tool(): + return FakeTool( + name="identity_tool", + description="This tool simply forwards the input.", + parameters={ + "type": "object", + "properties": { + "param": { + "type": "string", + "description": "Input parameter.", + } + }, + }, + responses={ + "type": 
"object", + "properties": { + "param": { + "type": "string", + "description": "Same as the input parameter.", + } + }, + }, + function=lambda param: {"param": param}, + ) + + +@pytest.fixture(scope="module") +def no_input_no_output_tool(): + return FakeTool( + name="no_input_no_output_tool", + description="This tool takes no input parameters and returns no output parameters.", + parameters={"type": "object", "properties": {}}, + responses={"type": "object", "properties": {}}, + function=lambda: {}, + ) + + +@pytest.mark.asyncio +async def test_functional_agent_with_retrieval_retrieval_score_tool_callbacks(identity_tool): + callback_handler = CountingCallbackHandler() + knowledge_base_name = "test" + access_token = "your access token" + knowledge_base_id = 111 + search_db = BaizhongSearch( + knowledge_base_name=knowledge_base_name, + access_token=access_token, + knowledge_base_id=knowledge_base_id if knowledge_base_id != "" else None, + ) + search_tool = BaizhongSearchTool(db=search_db, description="城市建筑设计相关规定") + llm = FakeERNIEBotWithPresetResponses( + responses=[ + AIMessage("Text response", function_call=None), + AIMessage('{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}', function_call=None), + AIMessage("Text response", function_call=None), + AIMessage("Text response", function_call=None), + ] + ) + agent = ContextAugmentedFunctionAgent( + knowledge_base=search_tool, + llm=llm, + threshold=0.0, + tools=[identity_tool], + memory=FakeMemory(), + callbacks=[callback_handler], + ) + + await agent.run_llm([HumanMessage("Hello, world!")]) + assert callback_handler.llm_starts == 1 + assert callback_handler.llm_ends == 1 + assert callback_handler.llm_errors == 0 + + await agent.run_tool(identity_tool.tool_name, json.dumps({"param": "test"})) + assert callback_handler.tool_starts == 1 + assert callback_handler.tool_ends == 1 + assert callback_handler.tool_errors == 0 + with mock.patch("requests.post") as my_mock: + my_mock.return_value = MagicMock(status_code=200, 
json=lambda: EXAMPLE_RESPONSE) + await agent.run("Hello, world!") + assert callback_handler.run_starts == 1 + assert callback_handler.run_ends == 1 + assert callback_handler.run_errors == 0 + # call retrieval tool + assert callback_handler.tool_starts == 2 + assert callback_handler.tool_ends == 2 + assert callback_handler.tool_errors == 0 + + +@pytest.mark.asyncio +async def test_functional_agent_with_retrieval_retrieval_score_tool_run_retrieval(identity_tool): + knowledge_base_name = "test" + access_token = "your access token" + knowledge_base_id = 111 + + search_db = BaizhongSearch( + knowledge_base_name=knowledge_base_name, + access_token=access_token, + knowledge_base_id=knowledge_base_id if knowledge_base_id != "" else None, + ) + search_tool = BaizhongSearchTool(db=search_db, description="城市建筑设计相关规定") + llm = FakeERNIEBotWithPresetResponses( + responses=[ + AIMessage('{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}', function_call=None), + AIMessage("Text response", function_call=None), + AIMessage("Text response", function_call=None), + ] + ) + agent = ContextAugmentedFunctionAgent( + knowledge_base=search_tool, + llm=llm, + threshold=0.0, + tools=[identity_tool], + memory=FakeMemory(), + ) + # Test retrieval success + with mock.patch("requests.post") as my_mock: + my_mock.return_value = MagicMock(status_code=200, json=lambda: EXAMPLE_RESPONSE) + response = await agent.run("Hello, world!") + assert response.text == "Text response" + # HumanMessage + assert ( + response.chat_history[0].content + == "背景信息为:Text response \n 给定背景信息,而不是先验知识,选择相应的工具回答或者根据背景信息直接回答问题:Hello, world!" 
+ ) + # AIMessage + assert response.chat_history[1].content == "Text response" + + assert len(response.steps) == 1 + assert response.steps[0].info == { + "tool_name": "KnowledgeBaseTool", + "tool_args": '{"query": "Hello, world!"}', + } + assert ( + response.steps[0].result + == '[{"id": "495735246643269", "content": "住房和城乡建设部规章城市管理执法办法", "title": "城市管理执法办法.pdf", ' + '"score": 0.01162862777709961}, {"id": "495735246643270", "content": "城市管理执法主管部门应当定期开展执法人员的培训和考核。",' + ' "title": "城市管理执法办法.pdf", "score": 0.011362016201019287}]' + ) + + # Test retrieval failed + llm = FakeERNIEBotWithPresetResponses( + responses=[ + AIMessage('{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}', function_call=None), + AIMessage("Text response", function_call=None), + AIMessage("Text response", function_call=None), + ] + ) + agent = ContextAugmentedFunctionAgent( + knowledge_base=search_tool, + llm=llm, + threshold=0.1, + tools=[identity_tool], + memory=FakeMemory(), + ) + with mock.patch("requests.post") as my_mock: + my_mock.return_value = MagicMock(status_code=200, json=lambda: EXAMPLE_RESPONSE) + response = await agent.run("Hello, world!") + assert response.text == "Text response" + # HumanMessage + assert response.chat_history[0].content == "Hello, world!" 
+ # AIMessage + assert response.chat_history[1].content == "Text response" From f17e2255238793d43d6298f25e7d90ac9c5c7d81 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 04:32:27 +0000 Subject: [PATCH 40/43] Update steps_taken --- .../erniebot_agent/agents/function_agent.py | 2 +- .../agents/function_agent_with_retrieval.py | 20 +++++++++++-------- .../tools/langchain_retrieval_tool.py | 4 ++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index 49f0d7cd6..b3128f5de 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -125,7 +125,7 @@ def __init__( self._first_tools = [] async def _run( - self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: Optional[AgentStep] = [] + self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: List[AgentStep] = [] ) -> AgentResponse: chat_history: List[Message] = [] diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py index 90b869225..c5d86b869 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent_with_retrieval.py @@ -95,7 +95,9 @@ def __init__( self.search_tool = KnowledgeBaseTool() self.token_limit = token_limit - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: + async def _run( + self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: List[AgentStep] = [] + ) -> AgentResponse: results = await self._maybe_retrieval(prompt) if len(results["documents"]) > 0: # RAG branch @@ -109,7 +111,6 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age step_input = 
HumanMessage(content=self.rag_prompt.format(query=prompt, documents=docs)) chat_history: List[Message] = [] chat_history.append(step_input) - steps_taken: List[AgentStep] = [] tool_ret_json = json.dumps(results, ensure_ascii=False) tool_resp = ToolResponse(json=tool_ret_json, input_files=[], output_files=[]) @@ -197,12 +198,13 @@ def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, **kwargs): self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) self.search_tool = KnowledgeBaseTool() - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: + async def _run( + self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: List[AgentStep] = [] + ) -> AgentResponse: results = await self._maybe_retrieval(prompt) if results["is_relevant"] is True: # RAG chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] tool_args = json.dumps({"query": prompt}, ensure_ascii=False) await self._callback_manager.on_tool_start( @@ -309,12 +311,13 @@ def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: fl self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"]) self.search_tool = KnowledgeBaseTool() - async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: + async def _run( + self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: List[AgentStep] = [] + ) -> AgentResponse: results = await self._maybe_retrieval(prompt) if len(results["documents"]) > 0: # RAG chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] tool_args = json.dumps({"query": prompt}, ensure_ascii=False) await self._callback_manager.on_tool_start( @@ -393,13 +396,14 @@ def __init__(self, knowledge_base: BaizhongSearch, top_k: int = 3, threshold: fl self.search_tool = KnowledgeBaseTool() self.query_rewrite = PromptTemplate(REWRITE_PROMPT, input_variables=["query"]) - async def _run(self, prompt: str, 
files: Optional[Sequence[File]] = None) -> AgentResponse: + async def _run( + self, prompt: str, files: Optional[Sequence[File]] = None, steps_taken: List[AgentStep] = [] + ) -> AgentResponse: # Rewrite queries for retrieval results = await self._query_rewrite(prompt) if len(results) > 0: # RAG chat_history: List[Message] = [] - steps_taken: List[AgentStep] = [] tool_args = json.dumps({"query": prompt}, ensure_ascii=False) await self._callback_manager.on_tool_start( diff --git a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py index e038ece9d..2ed444a99 100644 --- a/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/langchain_retrieval_tool.py @@ -32,8 +32,8 @@ class LangChainRetrievalTool(Tool): def __init__( self, db, - name: Optional[str] = None, - description: Optional[str] = None, + name: str = "This is a tool", + description: str = "This is the tool for search", threshold: float = 0.0, input_type=None, output_type=None, From b24e6857986c929cc65ac13ee59df720a967bc5b Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 04:34:04 +0000 Subject: [PATCH 41/43] reformat --- .../agents/test_context_augmented_function_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py b/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py index 3a2bd35c9..1ef1c0006 100644 --- a/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py +++ b/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py @@ -56,7 +56,7 @@ def no_input_no_output_tool(): @pytest.mark.asyncio -async def test_functional_agent_with_retrieval_retrieval_score_tool_callbacks(identity_tool): +async def test_function_agent_with_retrieval_retrieval_score_tool_callbacks(identity_tool): 
callback_handler = CountingCallbackHandler() knowledge_base_name = "test" access_token = "your access token" @@ -106,7 +106,7 @@ async def test_functional_agent_with_retrieval_retrieval_score_tool_callbacks(id @pytest.mark.asyncio -async def test_functional_agent_with_retrieval_retrieval_score_tool_run_retrieval(identity_tool): +async def test_function_agent_with_retrieval_retrieval_score_tool_run_retrieval(identity_tool): knowledge_base_name = "test" access_token = "your access token" knowledge_base_id = 111 From a860defd6d3e41212a366164df7188b78482566a Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 04:36:13 +0000 Subject: [PATCH 42/43] Update erniebot --- erniebot/src/erniebot/resources/chat_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index 95263e1f1..ee0982836 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -573,4 +573,4 @@ def to_message(self) -> Dict[str, Any]: message["function_call"] = self.function_call else: message["content"] = self.result - return message \ No newline at end of file + return message From e3dd45fd033ea9e08cc1fad4a94da5a4f0998d88 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 19 Jan 2024 04:43:40 +0000 Subject: [PATCH 43/43] Update unitest --- .../agents/test_context_augmented_function_agent.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py b/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py index 1ef1c0006..e246fbbd9 100644 --- a/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py +++ b/erniebot-agent/tests/unit_tests/agents/test_context_augmented_function_agent.py @@ -56,7 +56,7 @@ def no_input_no_output_tool(): @pytest.mark.asyncio -async def 
test_function_agent_with_retrieval_retrieval_score_tool_callbacks(identity_tool): +async def test_function_agent_with_retrieval_context_augmented_callbacks(identity_tool): callback_handler = CountingCallbackHandler() knowledge_base_name = "test" access_token = "your access token" @@ -106,7 +106,7 @@ async def test_function_agent_with_retrieval_retrieval_score_tool_callbacks(iden @pytest.mark.asyncio -async def test_function_agent_with_retrieval_retrieval_score_tool_run_retrieval(identity_tool): +async def test_function_agent_with_retrieval_context_augmented_run_retrieval(identity_tool): knowledge_base_name = "test" access_token = "your access token" knowledge_base_id = 111 @@ -143,8 +143,7 @@ async def test_function_agent_with_retrieval_retrieval_score_tool_run_retrieval( ) # AIMessage assert response.chat_history[1].content == "Text response" - - assert len(response.steps) == 1 + assert len(response.steps) == 2 assert response.steps[0].info == { "tool_name": "KnowledgeBaseTool", "tool_args": '{"query": "Hello, world!"}',