From 2cb543db77176b62c75085f6d4a6a3b40dbfbee8 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:30:54 +0200 Subject: [PATCH] Multi agents with manager (#32687) * Add Multi agents with a hierarchical system --- src/transformers/agents/__init__.py | 4 +- src/transformers/agents/agents.py | 196 ++++++++++++++---- src/transformers/agents/default_tools.py | 2 +- src/transformers/agents/prompts.py | 16 +- src/transformers/agents/python_interpreter.py | 2 +- tests/agents/test_agents.py | 25 ++- 6 files changed, 193 insertions(+), 52 deletions(-) diff --git a/src/transformers/agents/__init__.py b/src/transformers/agents/__init__.py index 4235d4c0d70d6c..438bd313b5e46e 100644 --- a/src/transformers/agents/__init__.py +++ b/src/transformers/agents/__init__.py @@ -24,7 +24,7 @@ _import_structure = { - "agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], + "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "llm_engine": ["HfApiEngine", "TransformersEngine"], "monitoring": ["stream_to_gradio"], "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"], @@ -45,7 +45,7 @@ _import_structure["translation"] = ["TranslationTool"] if TYPE_CHECKING: - from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox + from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .llm_engine import HfApiEngine, TransformersEngine from .monitoring import stream_to_gradio from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 8152b3213b8f71..5a4aea28d97061 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -57,8 +57,11 @@ class CustomFormatter(logging.Formatter): 
bold_yellow = "\x1b[33;1m" red = "\x1b[31;20m" green = "\x1b[32;20m" + bold_green = "\x1b[32;20;1m" bold_red = "\x1b[31;1m" bold_white = "\x1b[37;1m" + orange = "\x1b[38;5;214m" + bold_orange = "\x1b[38;5;214;1m" reset = "\x1b[0m" format = "%(message)s" @@ -66,11 +69,14 @@ class CustomFormatter(logging.Formatter): logging.DEBUG: grey + format + reset, logging.INFO: format, logging.WARNING: bold_yellow + format + reset, - 31: reset + format + reset, - 32: green + format + reset, - 33: bold_white + format + reset, logging.ERROR: red + format + reset, logging.CRITICAL: bold_red + format + reset, + 31: reset + format + reset, + 32: green + format + reset, + 33: bold_green + format + reset, + 34: bold_white + format + reset, + 35: orange + format + reset, + 36: bold_orange + format + reset, } def format(self, record): @@ -311,12 +317,32 @@ class AgentGenerationError(AgentError): def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions) + if "<<tool_names>>" in prompt: tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()] prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names)) + return prompt +def show_agents_descriptions(managed_agents: list): + managed_agents_descriptions = """ +You can also give requests to team members. +Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request. +Given that this team member is a real human, you should be very verbose in your request. 
+Here is a list of the team members that you can call:""" + for agent in managed_agents.values(): + managed_agents_descriptions += f"\n- {agent.name}: {agent.description}" + return managed_agents_descriptions + + +def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str: + if managed_agents is not None: + return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)) + else: + return prompt_template.replace("<<managed_agents_descriptions>>", "") + + def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str: if "<<authorized_imports>>" not in prompt_template: raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.") @@ -335,8 +361,8 @@ def __init__( tool_parser=parse_json_tool_call, add_base_tools: bool = False, verbose: int = 0, - memory_verbose: bool = False, grammar: Dict[str, str] = None, + managed_agents: List = None, ): self.agent_name = self.__class__.__name__ self.llm_engine = llm_engine @@ -350,6 +376,10 @@ def __init__( self.tool_parser = tool_parser self.grammar = grammar + self.managed_agents = None + if managed_agents is not None: + self.managed_agents = {agent.name: agent for agent in managed_agents} + if isinstance(tools, Toolbox): self._toolbox = tools if add_base_tools: @@ -364,10 +394,10 @@ def __init__( self.system_prompt = format_prompt_with_tools( self._toolbox, self.system_prompt_template, self.tool_description_template ) + self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) self.prompt = None self.logs = [] self.task = None - self.memory_verbose = memory_verbose if verbose == 0: logger.setLevel(logging.WARNING) @@ -388,13 +418,14 @@ def initialize_for_run(self): self.system_prompt_template, self.tool_description_template, ) + self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) if hasattr(self, "authorized_imports"): self.system_prompt = format_prompt_with_imports( self.system_prompt, 
list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) ) self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] - self.logger.warn("======== New task ========") - self.logger.log(33, self.task) + self.logger.log(33, "======== New task ========") + self.logger.log(34, self.task) self.logger.debug("System prompt is as follows:") self.logger.debug(self.system_prompt) @@ -444,12 +475,12 @@ def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> if "error" in step_log or "observation" in step_log: if "error" in step_log: message_content = ( - f"[OUTPUT OF STEP {i}] Error: " + f"[OUTPUT OF STEP {i}] -> Error:\n" + str(step_log["error"]) + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" ) elif "observation" in step_log: - message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}" + message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log['observation']}" tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content} memory.append(tool_response_message) @@ -477,7 +508,7 @@ def extract_action(self, llm_output: str, split_token: str) -> str: raise AgentParsingError( f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" ) - return rationale, action + return rationale.strip(), action.strip() def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: """ @@ -488,29 +519,44 @@ def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). arguments (Dict[str, str]): Arguments passed to the Tool. """ - if tool_name not in self.toolbox.tools: - error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(self.toolbox.tools.keys())}." 
+ available_tools = self.toolbox.tools + if self.managed_agents is not None: + available_tools = {**available_tools, **self.managed_agents} + if tool_name not in available_tools: + error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." self.logger.error(error_msg, exc_info=1) raise AgentExecutionError(error_msg) try: if isinstance(arguments, str): - observation = self.toolbox.tools[tool_name](arguments) - else: + observation = available_tools[tool_name](arguments) + elif isinstance(arguments, dict): for key, value in arguments.items(): # if the value is the name of a state variable like "image.png", replace it with the actual value if isinstance(value, str) and value in self.state: arguments[key] = self.state[value] - observation = self.toolbox.tools[tool_name](**arguments) + observation = available_tools[tool_name](**arguments) + else: + raise AgentExecutionError( + f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." 
+ ) return observation except Exception as e: - raise AgentExecutionError( - f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" - f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(self.toolbox.tools[tool_name])}" - ) - - def log_code_action(self, code_action: str) -> None: - self.logger.warning("==== Agent is executing the code below:") + if tool_name in self.toolbox.tools: + raise AgentExecutionError( + f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" + f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}" + ) + elif tool_name in self.managed_agents: + raise AgentExecutionError( + f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" + f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" + ) + + def log_rationale_code_action(self, rationale: str, code_action: str) -> None: + self.logger.warning("=== Agent thoughts:") + self.logger.log(31, rationale) + self.logger.warning(">>> Agent is executing the code below:") if is_pygments_available(): self.logger.log( 31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord")) @@ -612,12 +658,12 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): # Parse try: - _, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") except Exception as e: self.logger.debug( f"Error in extracting action, trying to parse the whole output as code. 
Error trace: {e}" ) - code_action = llm_output + rationale, code_action = "", llm_output try: code_action = self.parse_code_blob(code_action) @@ -627,7 +673,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): return error_msg # Execute - self.log_code_action(code_action) + self.log_rationale_code_action(rationale, code_action) try: available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools} output = self.python_evaluator( @@ -813,6 +859,9 @@ def planning_step(self, task, is_first_step: bool = False, iteration: int = None "content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format( task=task, tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), + managed_agents_descriptions=( + show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else "" + ), answer_facts=answer_facts, ), } @@ -829,8 +878,8 @@ def planning_step(self, task, is_first_step: bool = False, iteration: int = None {answer_facts} ```""".strip() self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) - self.logger.debug("===== Initial plan: =====") - self.logger.debug(final_plan_redaction) + self.logger.log(36, "===== Initial plan =====") + self.logger.log(35, final_plan_redaction) else: # update plan agent_memory = self.write_inner_memory_from_logs( summary_mode=False @@ -857,6 +906,9 @@ def planning_step(self, task, is_first_step: bool = False, iteration: int = None "content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format( task=task, tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), + managed_agents_descriptions=( + show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else "" + ), facts_update=facts_update, remaining_steps=(self.max_iterations - iteration), ), @@ -872,8 +924,8 @@ def planning_step(self, task, is_first_step: bool = False, iteration: int = None {facts_update} ```""" self.logs.append({"plan": 
final_plan_redaction, "facts": final_facts_redaction}) - self.logger.debug("===== Updated plan: =====") - self.logger.debug(final_plan_redaction) + self.logger.log(36, "===== Updated plan =====") + self.logger.log(35, final_plan_redaction) class ReactJsonAgent(ReactAgent): @@ -945,7 +997,9 @@ def step(self): current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} # Execute - self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}") + self.logger.warning("=== Agent thoughts:") + self.logger.log(31, rationale) + self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") if tool_name == "final_answer": if isinstance(arguments, dict): if "answer" in arguments: @@ -961,6 +1015,8 @@ def step(self): current_step_logs["final_answer"] = answer return current_step_logs else: + if arguments is None: + arguments = {} observation = self.execute_tool_call(tool_name, arguments) observation_type = type(observation) if observation_type == AgentText: @@ -1050,12 +1106,12 @@ def step(self): except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") - self.logger.debug("===== Output message of the LLM: =====") + self.logger.debug("=== Output message of the LLM:") self.logger.debug(llm_output) current_step_logs["llm_output"] = llm_output # Parse - self.logger.debug("===== Extracting action =====") + self.logger.debug("=== Extracting action ===") try: rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") except Exception as e: @@ -1072,22 +1128,30 @@ def step(self): current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} # Execute - self.log_code_action(code_action) + self.log_rationale_code_action(rationale, code_action) try: + static_tools = { + **BASE_PYTHON_TOOLS.copy(), + **self.toolbox.tools, + } + if self.managed_agents is not None: + static_tools = {**static_tools, **self.managed_agents} 
result = self.python_evaluator( code_action, - static_tools={ - **BASE_PYTHON_TOOLS.copy(), - **self.toolbox.tools, - }, + static_tools=static_tools, custom_tools=self.custom_tools, state=self.state, authorized_imports=self.authorized_imports, ) - information = self.state["print_outputs"] self.logger.warning("Print outputs:") - self.logger.log(32, information) - current_step_logs["observation"] = information + self.logger.log(32, self.state["print_outputs"]) + if result is not None: + self.logger.warning("Last output from code snippet:") + self.logger.log(32, str(result)) + observation = "Print outputs:\n" + self.state["print_outputs"] + if result is not None: + observation += "Last output from code snippet:\n" + str(result)[:100000] + current_step_logs["observation"] = observation except Exception as e: error_msg = f"Code execution failed due to the following error:\n{str(e)}" if "'dict' object has no attribute 'read'" in str(e): @@ -1095,7 +1159,57 @@ def step(self): raise AgentExecutionError(error_msg) for line in code_action.split("\n"): if line[: len("final_answer")] == "final_answer": - self.logger.warning(">>> Final answer:") + self.logger.log(33, "Final answer:") self.logger.log(32, result) current_step_logs["final_answer"] = result return current_step_logs + + +class ManagedAgent: + def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False): + self.agent = agent + self.name = name + self.description = description + self.additional_prompting = additional_prompting + self.provide_run_summary = provide_run_summary + + def write_full_task(self, task): + full_task = f"""You're a helpful agent named '{self.name}'. +You have been submitted this task by your manager. +--- +Task: +{task} +--- +You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer. 
+ +Your final_answer WILL HAVE to contain these parts: +### 1. Task outcome (short version): +### 2. Task outcome (extremely detailed version): +### 3. Additional context (if relevant): + +Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost. +And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. +<<additional_prompting>>""" + if self.additional_prompting: + full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip() + else: + full_task = full_task.replace("\n<<additional_prompting>>", "").strip() + return full_task + + def __call__(self, request, **kwargs): + full_task = self.write_full_task(request) + output = self.agent.run(full_task, **kwargs) + if self.provide_run_summary: + answer = f"Here is the final answer from your managed agent '{self.name}':\n" + answer += str(output) + answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" + for message in self.agent.write_inner_memory_from_logs(summary_mode=True): + content = message["content"] + if len(str(content)) < 1000 or "[FACTS LIST]" in str(content): + answer += "\n" + str(content) + "\n---" + else: + answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---" + answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'." 
+ return answer + else: + return output diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 84bbf3a9738667..b02b12d5287cec 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -29,7 +29,7 @@ def custom_print(*args): - return " ".join(map(str, args)) + return None BASE_PYTHON_TOOLS = { diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index bbc674adc231e4..de8ad1d2849013 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -332,10 +332,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): --- Task: "What is the current age of the pope, raised to the power 0.36?" -Thought: I will use the tool `search` to get the age of the pope, then raise it to the power 0.36. +Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36. Code: ```py -pope_age = search(query="current pope age") +pope_age = wiki(query="current pope age") print("Pope age:", pope_age) ``` Observation: @@ -348,16 +348,16 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): final_answer(pope_current_age) ``` -Above example were using notional tools that might not exist for you. You only have acces to those tools: +Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool): <<tool_descriptions>> -You also can perform computations in the Python code that you generate. +<<managed_agents_descriptions>> Here are the rules you should always follow to solve your task: 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```' sequence, else you will fail. 2. Use only variables that you have defined! -3. Always use the right arguments for the tools. 
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. +3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'. 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 
@@ -410,6 +410,8 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Your plan can leverage any of these tools: {tool_descriptions} +{managed_agents_descriptions} + List of facts that you know: ``` {answer_facts} @@ -453,9 +455,11 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): {task} ``` -You have access to these tools: +You have access to these tools and only these: {tool_descriptions} +{managed_agents_descriptions} + Here is the up to date list of facts that you know: ``` {facts_update} diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 702363a21e483b..fbece2bebd350f 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -434,7 +434,7 @@ def evaluate_call(call, state, static_tools, custom_tools): global PRINT_OUTPUTS PRINT_OUTPUTS += output + "\n" # cap the number of lines - return output + return None else: # Assume it's a callable object output = func(*args, **kwargs) return output diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index c18a568fdf78d0..67cb31b7dac334 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -20,7 +20,14 @@ import pytest from transformers.agents.agent_types import AgentText -from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox +from transformers.agents.agents import ( + AgentMaxIterationsError, + CodeAgent, + ManagedAgent, + ReactCodeAgent, + ReactJsonAgent, + Toolbox, +) from transformers.agents.default_tools import PythonInterpreterTool from transformers.testing_utils import require_torch @@ -235,3 +242,19 @@ def test_function_persistence_across_steps(self): ) res = agent.run("ok") assert res[0] == 0.5 + + def test_init_managed_agent(self): + agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef) + managed_agent = ManagedAgent(agent, name="managed_agent", 
description="Empty") + assert managed_agent.name == "managed_agent" + assert managed_agent.description == "Empty" + + def test_agent_description_gets_correctly_inserted_in_system_prompt(self): + agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef) + managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") + manager_agent = ReactCodeAgent( + tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent] + ) + assert "You can also give requests to team members." not in agent.system_prompt + assert "<<managed_agents_descriptions>>" not in agent.system_prompt + assert "You can also give requests to team members." in manager_agent.system_prompt