Skip to content

Commit

Permalink
Merge pull request #17 from qingzhong1/eb8
Browse files Browse the repository at this point in the history
update example_group
  • Loading branch information
w5688414 authored Jan 9, 2024
2 parents 45a29ee + 44807fe commit 76ea7ee
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tools.utils import JsonUtil, ReportCallbackHandler

from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, SystemMessage
from erniebot_agent.prompt import PromptTemplate
Expand Down Expand Up @@ -74,7 +73,7 @@ def __init__(
else:
self._callback_manager = callbacks

async def run(self, report: Union[str, dict]) -> AgentResponse:
async def run(self, report: Union[str, dict[str, str]]) -> dict:
if isinstance(report, dict):
report = report["report"]
await self._callback_manager.on_run_start(agent=self, agent_name=self.name, prompt=report)
Expand Down
120 changes: 57 additions & 63 deletions erniebot-agent/applications/erniebot_researcher/group_agent.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import logging
import random
import re
import sys
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Union

from EditorActorAgent import EditorActorAgent
from RankingAgent import RankingAgent
from ResearchAgent import ResearchAgent
from ReviserActorAgent import ReviserActorAgent
from tools.utils import erniebot_chat
from editor_actor_agent import EditorActorAgent
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
from reviser_actor_agent import ReviserActorAgent
from tools.utils import JsonUtil

from erniebot_agent.agents.base import Agent
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import AIMessage, HumanMessage

logger = logging.getLogger(__name__)
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]


class GroupChat:
class GroupChat(JsonUtil):
def __init__(
self,
agents: List[Agent],
agents,
llm: BaseERNIEBot,
llm_long: BaseERNIEBot,
max_round: int = 10,
admin_name: str = "Admin",
func_call_filter: bool = True,
Expand All @@ -28,6 +31,8 @@ def __init__(
):
self.agents = agents
self.max_round = max_round
self.llm = llm
self.llm_long = llm_long
self.admin_name = admin_name
self.func_call_filter = func_call_filter
self.speaker_selection_method = speaker_selection_method
Expand All @@ -38,11 +43,11 @@ def agent_names(self) -> List[str]:
"""Return the names of the agents in the group chat."""
return [agent.name for agent in self.agents]

def agent_by_name(self, name: str) -> Agent:
def agent_by_name(self, name: str):
"""Returns the agent with a given name."""
return self.agents[self.agent_names.index(name)]

def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
def next_agent(self, agent, agents):
"""Return the next agent in the list."""
idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1
# Return the next agent
Expand All @@ -54,20 +59,20 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]

def select_speaker_msg(self, agents: List[Agent]) -> str:
def select_speaker_msg(self, agents) -> str:
return f"""您正在玩角色扮演游戏。可以使用以下角色:
{self._participant_roles(agents)}.
阅读下面的对话。
{[agent.name for agent in agents]}中选择下一个角色来扮演。仅返回扮演的角色。"""

def select_speaker_prompt(self, agents: List[Agent]) -> str:
def select_speaker_prompt(self, agents) -> str:
strs = ""
for i in agents:
strs += i.name + ":" + i.system_message + "\n"
return f"阅读下面的对话。 从{[agent.name for agent in agents]} 中选择下一个角色来扮演。仅返回扮演的角色。" + strs

def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
def manual_select_speaker(self, agents):
logger.info("请从以下列表中选择下一位Agent:")
_n_agents = len(agents)
for i in range(_n_agents):
Expand All @@ -93,7 +98,7 @@ def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
logger.info(f"输入无效。请输入 1 到 {_n_agents} 之间的数字。")
return None

def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agent], List[Agent]]:
def _prepare_and_select_agents(self, last_speaker):
if self.speaker_selection_method.lower() not in _VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
Expand Down Expand Up @@ -135,36 +140,28 @@ def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agen
selected_agent = None
return selected_agent, agents

def select_speaker(self, last_speaker: Agent, messages: List):
async def select_speaker(self, last_speaker, messages: List):
"""Select the next speaker."""
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
# auto speaker selection
respose = erniebot_chat(messages=messages, system=self.select_speaker_prompt(agents))
respose = await self.llm_long.chat(messages=messages, system=self.select_speaker_prompt(agents))
if not respose:
return self.next_agent(last_speaker, agents)

# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(respose, agents)
if len(mentions) == 1:
respose = next(iter(mentions))
else:
# Return the result
try:
mentions = self._mentioned_agents(respose.content, agents)
return self.agent_by_name(mentions)
except Exception as e:
logger.error(e)
logger.warning(
"GroupChat select_speaker failed to resolve the next speaker's name. "
+ f"This is because the speaker selection OAI call returned:\n{respose}"
)

# Return the result
try:
return self.agent_by_name(respose)
except ValueError:
return self.next_agent(last_speaker, agents)

async def a_select_speaker(self, last_speaker: Agent, messages):
"""Select the next speaker."""
return self.select_speaker(last_speaker, messages)

def _participant_roles(self, agents) -> str:
# Default to all agents registered
if agents is None:
Expand All @@ -180,20 +177,13 @@ def _participant_roles(self, agents) -> str:
roles.append(f"{agent.name}: {agent.system_message}".strip())
return "\n".join(roles)

def _mentioned_agents(self, message_content: str, agents: List[Agent]) -> Dict:
def _mentioned_agents(self, message_content: str, agents) -> str:
# Cast message content to str
mentions = dict()
for agent in agents:
regex = (
r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
) # Finds agent mentions, taking word boundaries into account
count = len(re.findall(regex, message_content)) # Pad the message to help with matching
if count > 0:
mentions[agent.name] = count
return mentions
next_agent_name = self.parse_json(message_content)["next_agent_name"]
return next_agent_name


class GroupChatManager(Agent):
class GroupChatManager:
"""(In preview) A chat manager agent that can manage a group chat of multiple agents."""

def __init__(
Expand All @@ -212,23 +202,26 @@ def __init__(
self.human_input_mode = human_input_mode
self.system_message = system_message

async def _async_run(
async def run(
self,
query: str,
report: str,
speaker: Agent,
) -> Union[str, Dict, None]:
report,
speaker,
):
"""Run a group chat."""
report_list = [report]
messages = [{"role": "user", "content": "你需要对生成的报告进行质量检测,请调用已有的各种助手完成这个任务,每次只调用1个助手。请你只返回助手的名字"}]
content = """你需要对生成的报告进行质量检测,请调用已有的各种助手完成这个任务,每次只调用1个助手。
现在已经存在一份报告,你必须对它进行质量检测,检测后,如果你认为报告质量没有达到要求,你可以调用报告生成助手,重新生成报告。
请你需要返回一个json格式的字符串,{"next_agent_name":"下一次调用助手的名字"}"""
messages: List[Union[HumanMessage, AIMessage]] = [HumanMessage(content)]
notes = ""
for i in range(self.groupchat.max_round):
if i == self.groupchat.max_round - 1:
# the last round
break
try:
# select the next speaker
speaker = self.groupchat.select_speaker(speaker, messages)
speaker = await self.groupchat.select_speaker(speaker, messages)
# if speaker
except KeyboardInterrupt:
# let the admin agent speak if interrupted
Expand All @@ -239,27 +232,28 @@ async def _async_run(
# admin agent is not found in the participants
raise
if isinstance(speaker, EditorActorAgent):
respose = await speaker._async_run(report)
respose = await speaker.run(report)
notes = respose.get("notes", "")
messages.append(
{"role": "assistant", "content": "调用" + speaker.name + "得到的结果为" + str(respose)}
)
messages.append(AIMessage("调用" + speaker.name + "得到的结果为" + str(respose)))
elif isinstance(speaker, ReviserActorAgent):
report_list.append(await speaker._async_run(report, notes))
report_list.append(await speaker.run(report, notes))
report = report_list[-1]
messages.append({"role": "assistant", "content": "调用" + speaker.name + "对报告进行了修订"})
messages.append(AIMessage("调用" + speaker.name + "对报告进行了修订"))
elif isinstance(speaker, ResearchAgent):
report, _ = await speaker._async_run(query)
report_list.append(report)
messages.append({"role": "assistant", "content": "调用" + speaker.name + "重新生成了一份报告"})
report_str, paragraphs = await speaker.run(query)
report_list.append({"report": report_str, "paragraphs": paragraphs})
messages.append(AIMessage("调用" + speaker.name + "重新生成了一份报告"))
elif isinstance(speaker, RankingAgent):
report = await speaker._async_run(report_list, query)
messages.append({"role": "assistant", "content": "调用" + speaker.name + "对多个报告进行了排序,得到最优的报告"})
report_list, report = await speaker.run(report_list, query)
messages.append(AIMessage("调用" + speaker.name + "对多个报告进行了排序,得到最优的报告"))
elif isinstance(speaker, PolishAgent):
report_str, _ = await speaker.run(report["report"], report["paragraphs"])
report = {"report": report_str, "paragraphs": report["paragraphs"]}
report_list.append(report)

if self.human_input_mode:
reply = input("是否停止,如果您认为生成的report符合要求,则请输入yes,否则输入no\n请输入:")
if reply == "yes":
break
messages.append(
{"role": "user", "content": "你需要对生成的报告进行质量检测,请调用已有的各种助手完成这个任务,每次只调用1个助手。请你只返回助手的名字"}
)
messages.append(HumanMessage(content))
return report
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tools.utils import JsonUtil, ReportCallbackHandler, add_citation, write_md_to_pdf

from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, SystemMessage
from erniebot_agent.prompt import PromptTemplate
Expand Down Expand Up @@ -62,7 +61,7 @@ def __init__(
else:
self._callback_manager = callbacks

async def run(self, report: str, summarize=None) -> AgentResponse:
async def run(self, report: str, summarize=None):
await self._callback_manager.on_run_start(agent=self, prompt=report)
agent_resp = await self._run(report, summarize)
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tools.utils import JsonUtil, ReportCallbackHandler

from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, SystemMessage
from erniebot_agent.prompt import PromptTemplate
Expand Down Expand Up @@ -53,7 +52,7 @@ def __init__(
else:
self._callback_manager = callbacks

async def run(self, list_reports: List[Union[str, dict]], query: str) -> AgentResponse:
async def run(self, list_reports: List[Union[str, dict]], query: str):
await self._callback_manager.on_run_start(agent=self, agent_name=self.name, prompt=query)
agent_resp = await self._run(query=query, list_reports=list_reports)
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tools.utils import JsonUtil, ReportCallbackHandler

from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, SystemMessage
from erniebot_agent.prompt.prompt_template import PromptTemplate
Expand Down Expand Up @@ -39,7 +38,7 @@ def __init__(
else:
self._callback_manager = callbacks

async def run(self, draft: Union[str, Dict], notes: str) -> AgentResponse:
async def run(self, draft: Union[str, Dict], notes: str) -> str:
if isinstance(draft, dict):
await self._callback_manager.on_run_start(agent=self, prompt=draft["report"])
else:
Expand Down
Loading

0 comments on commit 76ea7ee

Please sign in to comment.