diff --git a/erniebot-agent/applications/erniebot_researcher/README.md b/erniebot-agent/applications/erniebot_researcher/README.md index e34783cf..020d8341 100644 --- a/erniebot-agent/applications/erniebot_researcher/README.md +++ b/erniebot-agent/applications/erniebot_researcher/README.md @@ -30,6 +30,7 @@ Agents利用ernie-4.0和ernie-longtext来完成研究任务, ernie-4.0主要 + 多个报告Agent并行生成,并保持一定的多样性。 + 使用思维链技术对多个报告进行质量评估和排序,克服伪随机性,并选择最优的报告。 + 使用反思机制对报告进行修订和完善。 ++ 新增检索增强和chain of verification对事实进行校验 + 使用润色机制提升报告的整体可读性,融合更多的细节描述。 **注意** diff --git a/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py b/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py new file mode 100644 index 00000000..5b2047ab --- /dev/null +++ b/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py @@ -0,0 +1,244 @@ +import logging +import re +from typing import Any, List, Optional + +from tools.utils import JsonUtil, ReportCallbackHandler + +from erniebot_agent.chat_models.erniebot import BaseERNIEBot +from erniebot_agent.memory import HumanMessage, Message, SystemMessage +from erniebot_agent.prompt import PromptTemplate + +logger = logging.getLogger(__name__) +PLAN_VERIFICATIONS_PROMPT = """ +为了验证给出的内容中数字性表述的准确性,您需要创建一系列验证问题, +用于测试原始基线响应中的事实主张。例如,如果长格式响应的一部分包含 +“墨西哥-美国战争是 1846 年至 1848 年美国和墨西哥之间的武装冲突”, +那么一种可能的验证问题可以是:“墨西哥-美国战争何时开始以及结束?” +给定内容:{{base_context}} +您需要按照列表输出,输出段落中的事实和相应的验证问题。 +[{"fact": <段落中的事实>, "question": <验证问题,通过结合查询和事实生成>}, +{"fact": <段落中的事实>, "question": <验证问题,通过结合查询和事实生成>}, ...] 
+""" +ANWSER_PROMPT = """ +根据外部知识回答问题。如果给出的外部知识不能解答提出的问题, +请直接输出"无法回答",无需提供答案。给定问题:\n{{question}}\n +外部知识:{{content}}\n回答: +""" +CHECK_CLAIM_PROMPT = """ +根据给定的问题和回答,判断事实中数字描述是否正确。如果认为事实中的数字描述不正确, +请根据问题和回答提供修正。您的输出应为 JSON 格式: +{"is_correct": <事实是否正确>, "modify": <对不正确的事实进行修正>} +给定问题: {{question}}\n回答: {{answer}}\n事实: {{claim}} +""" + +FINAL_RESPONSE_PROMPT = """ +根据提供的背景知识,对原始内容进行改写。确保改写后的内容中的数字与背景知识中的数字一致。 +您必须修正原始内容中的数字。原始内容:{{origin_content}}\n背景知识:{{context}} +改写内容: +""" + + +class FactCheckerAgent(JsonUtil): + DEFAULT_SYSTEM_MESSAGE = "你是一个事实检查助手,你的任务就是检查文本中的事实描述是否正确" + + def __init__( + self, + name: str, + llm: BaseERNIEBot, + retriever_db: Any, + system_message: Optional[SystemMessage] = None, + callbacks=None, + ): + """ + Initialize a fact_checker agent. + + args: + name: The name of the agent. + llm: An LLM for the agent to use. + retriever_db: A database for the agent to use. + system_message: A message to be displayed when the agent starts. + callbacks: A callback handler for the agent. + """ + self.name = name + self.llm = llm + self.retriever_db = retriever_db + self.prompt_plan_verifications = PromptTemplate( + PLAN_VERIFICATIONS_PROMPT, input_variables=["base_context"] + ) + self.prompt_anwser = PromptTemplate(ANWSER_PROMPT, input_variables=["question", "content"]) + self.prompt_check_claim = PromptTemplate( + CHECK_CLAIM_PROMPT, input_variables=["question", "answer", "claim"] + ) + self.prompt_final_response = PromptTemplate( + FINAL_RESPONSE_PROMPT, input_variables=["origin_content", "context"] + ) + self.system_message = ( + system_message.content if system_message is not None else self.DEFAULT_SYSTEM_MESSAGE + ) + if callbacks is None: + self._callback_manager = ReportCallbackHandler() + else: + self._callback_manager = callbacks + + async def run(self, report: str): + """ + The main logic of running the agent. + + Args: + report: Entered report text. + Returns: + The results of the agent's operation. 
+ """ + await self._callback_manager.on_run_start( + agent=self, agent_name=self.name, prompt=self.system_message + ) + agent_resp = await self._run(report=report) + await self._callback_manager.on_run_end(agent=self, response=agent_resp) + return agent_resp + + async def generate_anwser(self, question: str, context: str): + """ + Generate answers to questions based on background knowledge + + Args: + question: Indicates a question to be answered. + context: Represents background knowledge relevant to the problem. + + Returns: + Generated answers to questions. + """ + messages: List[Message] = [ + HumanMessage(content=self.prompt_anwser.format(question=question, content=context)) + ] + responese = await self.llm.chat(messages) + result = responese.content + return result + + async def check_claim(self, question: str, answer: str, claim: str): + """ + Use a large language model to conduct a conversation, verify a fact, + and suggest corrections if the fact is incorrect + + Args: + question: represents a fact-checking question + answer: represents a fact-checking answer + claim: indicates a fact that need to be verified + + Returns: + A dictionary containing verification results, + including whether the facts are correct and suggestions for correction. + """ + messages: List[Message] = [ + HumanMessage( + content=self.prompt_check_claim.format(question=question, answer=answer, claim=claim) + ) + ] + responese = await self.llm.chat(messages) + result = responese.content + result = self.parse_json(result) + return result + + async def verifications(self, facts_problems: List[dict]): + """ + Answer questions using external knowledge and then use the answers to the questions to verify + relevant facts. As it processes each question and fact pair, it obtains the context + relevant to the question, generates an answer to the question, checks whether the fact + is correct, and records the verification results. 
+
+        Args:
+            facts_problems: A list of dictionaries containing questions and related facts.
+
+        Returns:
+            Updated dictionary list of verified questions and facts.
+        """
+        for item in facts_problems:
+            question = item["question"]
+            claim = item["fact"]
+            context = self.retriever_db.search(question)
+            context = [i["content"] for i in context]
+            item["evidence"] = context
+            answer = await self.generate_anwser(question, context)
+            item["anwser"] = answer
+            result = await self.check_claim(question, answer, claim)
+            item["is_correct"] = result["is_correct"]
+            if result["is_correct"] is False:
+                item["modify"] = result["modify"]
+            else:
+                item["modify"] = claim
+            self._callback_manager._agent_info(msg=item["modify"], subject="事实验证的结果", state="End")
+        return facts_problems
+
+    async def generate_final_response(self, content: str, verifications: List[dict]):
+        """
+        If the original factual questions pass fact verification,
+        the original content will be returned directly.
+        Otherwise, the original content will be corrected based
+        on the results of factual verification.
+
+        Args:
+            content: Original text content.
+            verifications: List of dictionaries containing fact verification results.
+
+        Returns:
+            The final generated reply content.
+        """
+        if all([item["is_correct"] for item in verifications]):
+            return content
+        else:
+            context = "".join([item["modify"] for item in verifications])
+            messages: List[Message] = [
+                HumanMessage(
+                    content=self.prompt_final_response.format(origin_content=content, context=context)
+                )
+            ]
+            response = await self.llm.chat(messages)
+            result = response.content
+            return result
+
+    async def report_fact(self, report: str):
+        """
+        Filter out sentences containing numbers in text.
+        Extract facts from the filtered sentences.
+        Extract validation questions and verify each extracted fact.
+        Example:
+        "Mexican-American War
+        was the armed conflict between the United States and Mexico from 1846 to 1848," then one possibility
+        A validation question to check these dates could be When did the Mexican-American War begin and end?
+
+        Args:
+            report: The original text content.
+
+        Returns:
+            The final generated reply content.
+        """
+        report_list = report.split("\n\n")
+        text = []
+        for item in report_list:
+            if item.strip().startswith("#"):
+                text.append(item)
+            else:
+                contains_numbers = re.findall(r"\b\d+\b", item)
+                if contains_numbers:
+                    messages: List[Message] = [
+                        HumanMessage(content=self.prompt_plan_verifications.format(base_context=item))
+                    ]
+                    response = await self.llm.chat(messages)
+                    result: List[dict] = self.parse_json(response.content)
+                    fact_check_result: List[dict] = await self.verifications(result)
+                    new_item: str = await self.generate_final_response(item, fact_check_result)
+                    text.append(new_item)
+                else:
+                    text.append(item)
+        return "\n\n".join(text)
+
+    async def _run(self, report: str):
+        """
+        The main logic of running the agent.
+
+        Args:
+            report: Entered report text.
+        Returns:
+            Processed report text.
+ """ + report = await self.report_fact(report) + return report diff --git a/erniebot-agent/applications/erniebot_researcher/research_team.py b/erniebot-agent/applications/erniebot_researcher/research_team.py index fa2d057e..f904b04d 100644 --- a/erniebot-agent/applications/erniebot_researcher/research_team.py +++ b/erniebot-agent/applications/erniebot_researcher/research_team.py @@ -2,6 +2,7 @@ from typing import List, Optional from editor_actor_agent import EditorActorAgent +from fact_check_agent import FactCheckerAgent from polish_agent import PolishAgent from ranking_agent import RankingAgent from research_agent import ResearchAgent @@ -16,6 +17,7 @@ def __init__( ranker_actor: RankingAgent, editor_actor: EditorActorAgent, reviser_actor: ReviserActorAgent, + checker_actor: FactCheckerAgent, polish_actor: Optional[PolishAgent] = None, user_agent: Optional[UserProxyAgent] = None, use_reflection: bool = False, @@ -25,6 +27,7 @@ def __init__( self.revise_actor_instance = reviser_actor self.ranker_actor_instance = ranker_actor self.polish_actor_instance = polish_actor + self.checker_actor_instance = checker_actor self.user_agent = user_agent self.polish_actor = polish_actor self.use_reflection = use_reflection @@ -77,9 +80,9 @@ async def run(self, query, iterations=3): immedia_report = list_reports[0] revised_report = immedia_report - + checked_report = await self.checker_actor_instance.run(report=revised_report["report"]) revised_report, path = await self.polish_actor_instance.run( - report=revised_report["report"], + report=checked_report, summarize=revised_report["paragraphs"], ) return revised_report, path diff --git a/erniebot-agent/applications/erniebot_researcher/sample_report_example.py b/erniebot-agent/applications/erniebot_researcher/sample_report_example.py index c3f0ac40..bf9d1bf3 100644 --- a/erniebot-agent/applications/erniebot_researcher/sample_report_example.py +++ b/erniebot-agent/applications/erniebot_researcher/sample_report_example.py @@ -5,6 
+5,7 @@ import time from editor_actor_agent import EditorActorAgent +from fact_check_agent import FactCheckerAgent from langchain.embeddings.openai import OpenAIEmbeddings from polish_agent import PolishAgent from ranking_agent import RankingAgent @@ -160,6 +161,7 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long): dir_path=target_path, report_type=args.report_type, ) + checker_actor = FactCheckerAgent(name="fact_check", llm=llm, retriever_db=retriever_sets["full_text"]) ranker_actor = RankingAgent( llm=llm, llm_long=llm_long, @@ -171,6 +173,7 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long): "editor_actor": editor_actor, "reviser_actor": reviser_actor, "ranker_actor": ranker_actor, + "checker_actor": checker_actor, "polish_actor": polish_actor, } diff --git a/erniebot-agent/applications/erniebot_researcher/ui.py b/erniebot-agent/applications/erniebot_researcher/ui.py index 5222119c..e4a22ac4 100644 --- a/erniebot-agent/applications/erniebot_researcher/ui.py +++ b/erniebot-agent/applications/erniebot_researcher/ui.py @@ -5,6 +5,7 @@ import gradio as gr from editor_actor_agent import EditorActorAgent +from fact_check_agent import FactCheckerAgent from langchain.embeddings.openai import OpenAIEmbeddings from polish_agent import PolishAgent from ranking_agent import RankingAgent @@ -168,6 +169,12 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path): llm_long=llm_long, callbacks=ReportCallbackHandler(logger=logger), ) + checker_actor = FactCheckerAgent( + name="fact_check", + llm=llm, + retriever_db=retriever_sets["full_text"], + callbacks=ReportCallbackHandler(logger=logger), + ) polish_actor = PolishAgent( name="polish", llm=llm, @@ -184,6 +191,7 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path): "editor_actor": editor_actor, "reviser_actor": reviser_actor, "ranker_actor": ranker_actor, + "checker_actor": checker_actor, "polish_actor": polish_actor, }