From fdc9d6cc430496c32d67cf6f02402d86c477c2e3 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Fri, 31 May 2024 14:16:23 +0200
Subject: [PATCH] Add streaming, various fixes (#30838)

* Implement streaming run in ReAct agents

* Allow additional imports in code agents

* Python interpreter: support classes and exceptions, fixes
---
 docs/source/en/agents.md                      |  34 ++-
 src/transformers/agents/agents.py             | 157 ++++++++----
 src/transformers/agents/llm_engine.py         |  11 +-
 src/transformers/agents/prompts.py            | 106 ++++----
 src/transformers/agents/python_interpreter.py | 241 +++++++++++++++---
 src/transformers/agents/tools.py              |   4 +-
 tests/agents/test_python_interpreter.py       | 128 ++++++++++
 7 files changed, 522 insertions(+), 159 deletions(-)

diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md
index ae9e5db2b7897b..2cacaed5902c4d 100644
--- a/docs/source/en/agents.md
+++ b/docs/source/en/agents.md
@@ -28,8 +28,8 @@ An agent is a system that uses an LLM as its engine, and it has access to functi
 These *tools* are functions for performing a task, and they contain all the necessary descriptions for the agent to properly use them.
 
 The agent can be programmed to:
-- devise a series of actions/tools and run them all at once like the `CodeAgent` for example
-- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one like the `ReactJsonAgent` for example
+- devise a series of actions/tools and run them all at once, like the [`CodeAgent`] for example
+- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one, like the [`ReactJsonAgent`] for example
 
 ### Types of agents
 
@@ -42,8 +42,8 @@ This agent has a planning step, then generates python code to execute all its ac
 
 This is the go-to agent to solve reasoning tasks, since the ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) makes it really efficient to think on the basis of its previous observations.
 
 We implement two versions of the ReAct agent:
-- [`~ReactJsonAgent`] generates tool calls as a JSON in its output.
-- [`~ReactCodeAgent`] is a new type of ReactJsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
+- [`ReactJsonAgent`] generates tool calls as JSON in its output.
+- [`ReactCodeAgent`] is a new type of ReAct agent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
 
 > [!TIP]
 > Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about the ReAct agent.
 
@@ -124,7 +124,7 @@ You could use any `llm_engine` method as long as:
 
 You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools`, but use the default toolbox with the optional argument `add_base_tools=True`.
 
-Now you can create an agent, like `CodeAgent`, and run it. For convenience, we also provide the `HfEngine` class that uses `huggingface_hub.InferenceClient` under the hood.
+Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
 
 ```python
 from transformers import CodeAgent, HfEngine
 
 llm_engine = HfEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
 agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
 
 agent.run(
     "Could you translate this sentence from French, say it out loud and return the audio.",
     sentence="Où est la boulangerie la plus proche?",
 )
 ```
 This will be handy in case of emergency baguette need!
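For reference, any callable with this shape works as a custom `llm_engine`. Below is a minimal sketch mirroring what [`HfEngine`] does internally; the model name is only an example, and the helper imports come from this patch's `llm_engine` module:

```python
from huggingface_hub import InferenceClient
from transformers import CodeAgent
from transformers.agents.llm_engine import get_clean_message_list, llama_role_conversions

client = InferenceClient(model="meta-llama/Meta-Llama-3-70B-Instruct")

def custom_llm_engine(messages, stop_sequences=[]) -> str:
    # Merge consecutive same-role messages and map tool-response roles to user,
    # as HfEngine does before calling the API
    messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
    response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
    answer = response.choices[0].message.content
    # Strip a trailing stop sequence if the model emitted one
    for stop_seq in stop_sequences:
        if answer.endswith(stop_seq):
            answer = answer[: -len(stop_seq)]
    return answer

agent = CodeAgent(tools=[], llm_engine=custom_llm_engine, add_base_tools=True)
```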
-You can even leave the argument `llm_engine` undefined, and an [~HfEngine] will be created by default.
+You can even leave the argument `llm_engine` undefined, and an [`HfEngine`] will be created by default.
 
 ```python
 from transformers import CodeAgent
 
 agent = CodeAgent(tools=[], add_base_tools=True)
 
 agent.run(
     "Could you translate this sentence from French, say it out loud and give me the audio.",
     sentence="Où est la boulangerie la plus proche?",
 )
 ```
 
@@ -181,13 +181,27 @@ You can also run an agent consecutively for different tasks: each time the attri
 
 A Python interpreter executes the code on a set of inputs passed along with your tools. This should be safe because the only functions that can be called are the tools you provided (especially if the only tools are those provided by Hugging Face) and the print function, so you're already limited in what can be executed.
 
-The Python interpreter also doesn't allow any attribute lookup or imports (which shouldn't be needed for passing inputs/outputs to a small set of functions) so all the most obvious attacks shouldn't be an issue.
+The Python interpreter also doesn't allow imports outside of a safe list by default, so all the most obvious attacks shouldn't be an issue.
+You can still authorize additional imports by passing the authorized modules as a list of strings in the `additional_authorized_imports` argument upon initialization of your [`ReactCodeAgent`] or [`CodeAgent`]:
+
+```py
+>>> from transformers import ReactCodeAgent
+
+>>> agent = ReactCodeAgent(tools=[], additional_authorized_imports=['requests', 'bs4'])
+>>> agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
+
+(...)
+'Hugging Face – Blog'
+```
 
 The execution will stop at any code trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent.
 
+> [!WARNING]
+> The LLM can generate arbitrary code that will then be executed: do not add any unsafe imports!
+
 ### The system prompt
 
-An agent, or rather the LLM that drives the agent, generates an output based on the system prompt. The system prompt can be customized and tailored to the intended task. For example, check the system prompt for the `ReactCodeAgent` (below version is slightly simplified).
+An agent, or rather the LLM that drives the agent, generates an output based on the system prompt. The system prompt can be customized and tailored to the intended task. For example, check the system prompt for the [`ReactCodeAgent`] (the version below is slightly simplified).
 
 ```text
 You will be given a task to solve as best you can.
@@ -246,7 +260,7 @@ of the available tools.
 
 A tool is an atomic function to be used by an agent.
 
-You can for instance check the [~PythonInterpreterTool]: it has a name, a description, input descriptions, an output type, and a `__call__` method to perform the action.
+You can for instance check the [`PythonInterpreterTool`]: it has a name, a description, input descriptions, an output type, and a `__call__` method to perform the action.
 
 When the agent is initialized, the tool attributes are used to generate a tool description which is baked into the agent's system prompt. This lets the agent know which tools it can use and why.
 
@@ -259,7 +273,7 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
 - **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
 - **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
 - **Translation**: translates a given sentence from source language to target language.
-- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [~ReactJsonAgent] if you use `add_base_tools=True`, since code-based tools can already execute Python code
+- **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based tools can already execute Python code
 
 You can manually use a tool by calling the [`load_tool`] function with a task to perform.
 
diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py
index 64e810eb91f8b6..ad0b9fecc3ef18 100644
--- a/src/transformers/agents/agents.py
+++ b/src/transformers/agents/agents.py
@@ -26,7 +26,7 @@
 from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
 from .llm_engine import HfEngine, MessageRole
 from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
-from .python_interpreter import evaluate_python_code
+from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
 from .tools import (
     DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
     Tool,
@@ -84,8 +84,14 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
         return json_data
     except json.JSONDecodeError as e:
         place = e.pos
+        if json_blob[place - 1 : place + 2] == "},\n":
+            raise ValueError(
+                "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
+            )
         raise ValueError(
-            f"The JSON blob you used is invalid: due to the following error: {e}. JSON blob was: {json_blob}, decoding failed at '{json_blob[place-4:place+5]}'."
+            f"The JSON blob you used is invalid due to the following error: {e}.\n"
+            f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
+            f"'{json_blob[place-4:place+5]}'."
         )
     except Exception as e:
         raise ValueError(f"Error in parsing the JSON blob: {e}")
@@ -347,6 +353,7 @@ def toolbox(self) -> Toolbox:
         return self._toolbox
 
     def initialize_for_run(self, task: str, **kwargs):
+        self.token_count = 0
         self.task = task
         if len(kwargs) > 0:
             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
@@ -380,7 +387,7 @@ def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
                 message_content = (
                     "Error: "
                     + str(step_log["error"])
-                    + "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches.\n"
+                    + "\nNow let's retry: take care not to repeat previous errors! 
If you have retried several times, try a completely different approach.\n" ) elif "observation" in step_log: message_content = f"Observation: {step_log['observation']}" @@ -409,6 +416,9 @@ def write_inner_memory_from_logs(self) -> List[Dict[str, str]]: ) return memory + def get_succinct_logs(self): + return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs] + def extract_action(self, llm_output: str, split_token: str) -> str: """ Parse action from the LLM output @@ -486,6 +496,7 @@ def __init__( llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + additional_authorized_imports: List[str] = [], **kwargs, ): super().__init__( @@ -504,6 +515,7 @@ def __init__( ) self.python_evaluator = evaluate_python_code + self.additional_authorized_imports = additional_authorized_imports def parse_code_blob(self, result: str) -> str: """ @@ -544,7 +556,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): self.prompt = [prompt_message, task_message] self.logger.info("====Executing with this prompt====") self.logger.info(self.prompt) - llm_output = self.llm_engine(self.prompt, stop_sequences=[""]) + llm_output = self.llm_engine(self.prompt, stop_sequences=[""]) if return_generated_code: return llm_output @@ -563,7 +575,12 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): self.log_code_action(code_action) try: available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools} - output = self.python_evaluator(code_action, available_tools, state=self.state) + output = self.python_evaluator( + code_action, + available_tools, + state=self.state, + authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, + ) self.logger.info(self.state["print_outputs"]) return output except Exception as e: @@ -597,7 +614,29 @@ def __init__( if "final_answer" not in self._toolbox.tools: self._toolbox.add_tool(FinalAnswerTool()) - def run(self, task: str, **kwargs): + def provide_final_answer(self, task) -> str: + """ + This method provides a final answer to the task, based on the logs of the agent's interactions. + """ + self.prompt = [ + { + "role": MessageRole.SYSTEM, + "content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", + } + ] + self.prompt += self.write_inner_memory_from_logs()[1:] + self.prompt += [ + { + "role": MessageRole.USER, + "content": f"Based on the above, please provide an answer to the following user request:\n{task}", + } + ] + try: + return self.llm_engine(self.prompt) + except Exception as e: + return f"Error in generating final llm output: {e}." + + def run(self, task: str, stream: bool = False, **kwargs): """ Runs the agent for the given task. 
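The new `stream` flag selects between the two code paths below: `direct_run` returns only the final answer, while `stream_run` yields one step-log dict per ReAct iteration, then the final answer last. A minimal usage sketch reusing the docstring's example task (illustrative, not part of the patch):

```python
from transformers import ReactCodeAgent

agent = ReactCodeAgent(tools=[])

# With stream=True, run() returns a generator: step logs carry keys such as
# "llm_output", "observation" or "error"; the last item is the final answer itself.
for step in agent.run("What is the result of 2 power 3.7384?", stream=True):
    print(step)
```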
@@ -614,41 +653,62 @@ def run(self, task: str, **kwargs): agent.run("What is the result of 2 power 3.7384?") ``` """ + if stream: + return self.stream_run(task, **kwargs) + else: + return self.direct_run(task, **kwargs) + + def stream_run(self, task: str, **kwargs): self.initialize_for_run(task, **kwargs) final_answer = None iteration = 0 while final_answer is None and iteration < self.max_iterations: try: - final_answer = self.step() + step_logs = self.step() + if "final_answer" in step_logs: + final_answer = step_logs["final_answer"] except AgentError as e: self.logger.error(e, exc_info=1) self.logs[-1]["error"] = e finally: iteration += 1 + yield self.logs[-1] if final_answer is None and iteration == self.max_iterations: error_message = "Reached max iterations." - self.logs.append({"error": AgentMaxIterationsError(error_message)}) + final_step_log = {"error": AgentMaxIterationsError(error_message)} + self.logs.append(final_step_log) self.logger.error(error_message, exc_info=1) + final_answer = self.provide_final_answer(task) + final_step_log["final_answer"] = final_answer + yield final_step_log + + yield final_answer - self.prompt = [ - { - "role": MessageRole.SYSTEM, - "content": "An agent tried to answer a user query but it failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", - } - ] - self.prompt += self.write_inner_memory_from_logs()[1:] - self.prompt += [ - { - "role": MessageRole.USER, - "content": f"Based on the above, please provide an answer to the following user request:\n{task}", - } - ] + def direct_run(self, task: str, **kwargs): + self.initialize_for_run(task, **kwargs) + + final_answer = None + iteration = 0 + while final_answer is None and iteration < self.max_iterations: try: - final_answer = self.llm_engine(self.prompt, stop_sequences=["Observation:"]) - except Exception as e: - final_answer = f"Error in generating final llm output: {e}." + step_logs = self.step() + if "final_answer" in step_logs: + final_answer = step_logs["final_answer"] + except AgentError as e: + self.logger.error(e, exc_info=1) + self.logs[-1]["error"] = e + finally: + iteration += 1 + + if final_answer is None and iteration == self.max_iterations: + error_message = "Reached max iterations." 
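+            # Out of iterations: fall back to provide_final_answer(), which asks the
+            # LLM for a best-effort answer based on the agent's logged memory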
+ final_step_log = {"error": AgentMaxIterationsError(error_message)} + self.logs.append(final_step_log) + self.logger.error(error_message, exc_info=1) + final_answer = self.provide_final_answer(task) + final_step_log["final_answer"] = final_answer return final_answer @@ -683,22 +743,24 @@ def step(self): """ agent_memory = self.write_inner_memory_from_logs() - self.logs[-1]["agent_memory"] = agent_memory.copy() self.prompt = agent_memory self.logger.debug("===== New step =====") # Add new step in logs - self.logs.append({}) + current_step_logs = {} + self.logs.append(current_step_logs) + current_step_logs["agent_memory"] = agent_memory.copy() + self.logger.info("===== Calling LLM with this last message: =====") self.logger.info(self.prompt[-1]) try: - llm_output = self.llm_engine(self.prompt, stop_sequences=["Observation:"]) + llm_output = self.llm_engine(self.prompt, stop_sequences=["", "Observation:"]) except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") self.logger.debug("===== Output message of the LLM: =====") self.logger.debug(llm_output) - self.logs[-1]["llm_output"] = llm_output + current_step_logs["llm_output"] = llm_output # Parse self.logger.debug("===== Extracting action =====") @@ -709,8 +771,8 @@ def step(self): except Exception as e: raise AgentParsingError(f"Could not parse the given action: {e}.") - self.logs[-1]["rationale"] = rationale - self.logs[-1]["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} + current_step_logs["rationale"] = rationale + current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} # Execute self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}") @@ -721,7 +783,8 @@ def step(self): answer = arguments if answer in self.state: # if the answer is a state variable, return the value answer = self.state[answer] - return answer + current_step_logs["final_answer"] = answer + return current_step_logs else: observation = self.execute_tool_call(tool_name, arguments) observation_type = type(observation) @@ -740,8 +803,8 @@ def step(self): updated_information = f"Stored '{observation_name}' in memory." self.logger.info(updated_information) - self.logs[-1]["observation"] = updated_information - return None + current_step_logs["observation"] = updated_information + return current_step_logs class ReactCodeAgent(ReactAgent): @@ -757,6 +820,7 @@ def __init__( llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + additional_authorized_imports: List[str] = [], **kwargs, ): super().__init__( @@ -775,6 +839,7 @@ def __init__( ) self.python_evaluator = evaluate_python_code + self.additional_authorized_imports = additional_authorized_imports def step(self): """ @@ -782,26 +847,27 @@ def step(self): The errors are raised here, they are caught and logged in the run() method. 
""" agent_memory = self.write_inner_memory_from_logs() - self.logs[-1]["agent_memory"] = agent_memory.copy() self.prompt = agent_memory.copy() self.logger.debug("===== New step =====") # Add new step in logs - self.logs.append({}) + current_step_logs = {} + self.logs.append(current_step_logs) + current_step_logs["agent_memory"] = agent_memory.copy() self.logger.info("===== Calling LLM with these last messages: =====") self.logger.info(self.prompt[-2:]) try: - llm_output = self.llm_engine(self.prompt, stop_sequences=["", "Observation:"]) + llm_output = self.llm_engine(self.prompt, stop_sequences=["", "Observation:"]) except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") self.logger.debug("===== Output message of the LLM: =====") self.logger.debug(llm_output) - self.logs[-1]["llm_output"] = llm_output + current_step_logs["llm_output"] = llm_output # Parse self.logger.debug("===== Extracting action =====") @@ -813,18 +879,23 @@ def step(self): error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" raise AgentParsingError(error_msg) - self.logs[-1]["rationale"] = rationale - self.logs[-1]["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} + current_step_logs["rationale"] = rationale + current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} # Execute self.log_code_action(code_action) try: available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools} - result = self.python_evaluator(code_action, available_tools, state=self.state) + result = self.python_evaluator( + code_action, + available_tools, + state=self.state, + authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, + ) information = self.state["print_outputs"] self.logger.warning("Print outputs:") self.logger.log(32, information) - self.logs[-1]["observation"] = information + current_step_logs["observation"] = information except Exception as e: error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}" if "'dict' object has no attribute 'read'" in str(e): @@ -834,5 +905,5 @@ def step(self): if line[: len("final_answer")] == "final_answer": self.logger.warning(">>> Final answer:") self.logger.log(32, result) - return result - return None + current_step_logs["final_answer"] = result + return current_step_logs diff --git a/src/transformers/agents/llm_engine.py b/src/transformers/agents/llm_engine.py index b696084090c001..76458b02677dbb 100644 --- a/src/transformers/agents/llm_engine.py +++ b/src/transformers/agents/llm_engine.py @@ -61,7 +61,6 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: llama_role_conversions = { - MessageRole.SYSTEM: MessageRole.USER, MessageRole.TOOL_RESPONSE: MessageRole.USER, } @@ -72,20 +71,14 @@ def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"): self.client = InferenceClient(model=self.model, timeout=120) def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str: - if "Meta-Llama-3" in self.model: - if "<|eot_id|>" not in stop_sequences: - stop_sequences.append("<|eot_id|>") - if "!!!!!" 
not in stop_sequences: - stop_sequences.append("!!!!!") - # Get clean message list messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) - # Get answer + # Get LLM output response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500) response = response.choices[0].message.content - # Remove stop sequences from the answer + # Remove stop sequences from LLM output for stop_seq in stop_sequences: if response[-len(stop_seq) :] == stop_seq: response = response[: -len(stop_seq)] diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 80c65a5144027d..f76734b09fa88f 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -68,7 +68,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): print(f"The translated question is {translated_question}.") answer = image_qa(image=image, question=translated_question) print(f"The answer is {answer}") -``` +``` --- Task: "Identify the oldest person in the `document` and create an image showcasing the result." @@ -79,7 +79,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): answer = document_qa(document, question="What is the oldest person?") print(f"The answer is {answer}.") image = image_generator(answer) -``` +``` --- Task: "Generate an image using the text given in the variable `caption`." @@ -88,7 +88,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Code: ```py image = image_generator(prompt=caption) -``` +``` --- Task: "Summarize the text given in the variable `text` and read it out loud." @@ -99,7 +99,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): summarized_text = summarizer(text) print(f"Summary: {summarized_text}") audio_summary = text_reader(summarized_text) -``` +``` --- Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image." @@ -110,7 +110,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): answer = text_qa(text=text, question=question) print(f"The answer is {answer}.") image = image_generator(answer) -``` +``` --- Task: "Caption the following `image`." @@ -119,39 +119,32 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Code: ```py caption = image_captioner(image) -``` +``` --- Above example were using tools that might not exist for you. You only have acces to those Tools: <> Remember to make sure that variables you use are all defined. -Be sure to provide a 'Code:\n```' sequence before the code and '```' after, else you will get an error. +Be sure to provide a 'Code:\n```' sequence before the code and '```' after, else you will get an error. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. -Now Begin! +Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ -DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. You have access to the following tools: -<> - -The way you use the tools is by specifying a json blob. -Specifically, this json should have a `action` key (name of the tool to use) and a `action_input` key (input to the tool). +DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. 
To do so, you have been given access to the following tools: <> +The way you use the tools is by specifying a json blob, ending with ''. +Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool). The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB: -Action: { "action": $TOOL_NAME, "action_input": $INPUT -} +} Make sure to have the $INPUT as a dictionnary in the right format for the tool you are using, and do not put variable names as input if you can find the right values. -You will be given: - -Task: the task you are given. - You should ALWAYS use the following format: Thought: you should always think about one action to take. Then use the action as follows: @@ -171,14 +164,14 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "image_transformer", "action_input": {"image": "image_1.jpg"} -} +} To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: Action: { "action": "final_answer", "action_input": {"answer": "insert your final answer here"} -} +} Here are a few examples using notional tools: @@ -190,7 +183,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "document_qa", "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} -} +} Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." @@ -199,7 +192,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "image_generator", "action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""} -} +} Observation: "image.png" Thought: I will now return the generated image. @@ -207,7 +200,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "final_answer", "action_input": "image.png" -} +} --- Task: "What is the result of the following operation: 5 + 3 + 1294.678?" @@ -217,7 +210,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "python_interpreter", "action_input": {"code": "5 + 3 + 1294.678"} -} +} Observation: 1302.678 Thought: Now that I know the result, I will now return it. @@ -225,7 +218,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "final_answer", "action_input": "1302.678" -} +} --- Task: "Which city has the highest population , Guangzhou or Shanghai?" @@ -235,7 +228,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "search", "action_input": "Population Guangzhou" -} +} Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] @@ -252,28 +245,30 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): { "action": "final_answer", "action_input": "Shanghai" -} +} Above example were using notional tools that might not exist for you. You only have acces to those tools: -<> -ALWAYS provide a 'Thought:' and an 'Action:' sequence. You MUST provide at least the 'Action:' sequence to move forward. +<> + +Here are the rules you should always follow to solve your task: +1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with , else you will fail. +2. 
Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead. +3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. +4. Never re-do a tool call that you previously did with the exact same parameters. -Now begin! +Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can. -You have access to the following tools: -<> - +To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. -At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use. -Then in the 'Code:' sequence, you shold write the code in simple Python. The code sequence must end with '/End code' sequence. +At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use. +Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '' sequence. During each intermediate step, you can use 'print()' to save whatever important information you will then need. -These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step. - +These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step. In the end you have to return a final answer using the `final_answer` tool. Here are a few examples using notional tools: @@ -285,7 +280,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): ```py answer = document_qa(document=document, question="Who is the oldest person mentioned?") print(answer) -``` +``` Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." Thought: I will now generate an image showcasing the oldest person. @@ -294,7 +289,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): ```py image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.") final_answer(image) -``` +``` --- Task: "What is the result of the following operation: 5 + 3 + 1294.678?" @@ -305,10 +300,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): ```py result = 5 + 3 + 1294.678 final_answer(result) -``` +``` --- -Task: "Which city has the highest population , Guangzhou or Shanghai?" +Task: "Which city has the highest population: Guangzhou or Shanghai?" Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. 
Code: @@ -317,7 +312,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): print("Population Guangzhou:", population_guangzhou) population_shanghai = search("Shanghai population") print("Population Shanghai:", population_shanghai) -``` +``` Observation: Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] Population Shanghai: '26 million (2019)' @@ -326,7 +321,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Code: ```py final_answer("Shanghai") -``` +``` --- Task: "What is the current age of the pope, raised to the power 0.36?" @@ -336,7 +331,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): ```py pope_age = search(query="current pope age") print("Pope age:", pope_age) -``` +``` Observation: Pope age: "The pope Francis is currently 85 years old." @@ -345,20 +340,21 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): ```py pope_current_age = 85 ** 0.36 final_answer(pope_current_age) -``` - +``` Above example were using notional tools that might not exist for you. You only have acces to those tools: -<> -You also can perform computations in the python code you generate. -Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```' sequence. You MUST provide at least the 'Code:' sequence to move forward. +<> -Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks. -Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result. +You also can perform computations in the Python code that you generate. -Remember to make sure that variables you use are all defined. -DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. +Here are the rules you should always follow to solve your task: +1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```' sequence, else you will fail. +2. Use only variables that you have defined! +3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. +4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. +5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. +6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. -Now Begin! +Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. 
""" diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 8ca1cd182095d8..992e9d14f19a6b 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -15,9 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import ast +import builtins import difflib from collections.abc import Mapping -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional class InterpretorError(ValueError): @@ -29,7 +30,25 @@ class InterpretorError(ValueError): pass -LIST_SAFE_MODULES = ["random", "math", "time", "queue", "itertools", "re", "stat", "statistics", "unicodedata"] +ERRORS = { + name: getattr(builtins, name) + for name in dir(builtins) + if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException) +} + + +LIST_SAFE_MODULES = [ + "random", + "collections", + "math", + "time", + "queue", + "itertools", + "re", + "stat", + "statistics", + "unicodedata", +] class BreakException(Exception): @@ -87,21 +106,62 @@ def evaluate_while(while_loop, state, tools): return None -def evaluate_function_def(function_def, state, tools): - def create_function(func_def, state, tools): - def new_func(*args): - new_state = state.copy() - for arg, val in zip(func_def.args.args, args): - new_state[arg.arg] = val - result = None - for node in func_def.body: - result = evaluate_ast(node, new_state, tools) - return result +def create_function(func_def, state, tools): + def new_func(*args, **kwargs): + func_state = state.copy() + arg_names = [arg.arg for arg in func_def.args.args] + for name, value in zip(arg_names, args): + func_state[name] = value + if func_def.args.vararg: + vararg_name = func_def.args.vararg.arg + func_state[vararg_name] = args + if func_def.args.kwarg: + kwarg_name = func_def.args.kwarg.arg + func_state[kwarg_name] = kwargs + + # Update function state with self and __class__ + if func_def.args.args and func_def.args.args[0].arg == "self": + if args: + func_state["self"] = args[0] + func_state["__class__"] = args[0].__class__ + + result = None + for stmt in func_def.body: + result = evaluate_ast(stmt, func_state, tools) + return result - return new_func + return new_func - tools[function_def.name] = create_function(function_def, state, tools) - return None + +def create_class(class_name, class_bases, class_body): + class_dict = {} + for key, value in class_body.items(): + class_dict[key] = value + return type(class_name, tuple(class_bases), class_dict) + + +def evaluate_function_def(func_def, state, tools): + tools[func_def.name] = create_function(func_def, state, tools) + return tools[func_def.name] + + +def evaluate_class_def(class_def, state, tools): + class_name = class_def.name + bases = [evaluate_ast(base, state, tools) for base in class_def.bases] + class_dict = {} + + for stmt in class_def.body: + if isinstance(stmt, ast.FunctionDef): + class_dict[stmt.name] = evaluate_function_def(stmt, state, tools) + elif isinstance(stmt, ast.Assign): + for target in stmt.targets: + class_dict[target.id] = evaluate_ast(stmt.value, state, tools) + else: + raise InterpretorError(f"Unsupported statement in class body: {stmt.__class__.__name__}") + + new_class = type(class_name, tuple(bases), class_dict) + state[class_name] = new_class + return new_class def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): @@ -176,11 
+236,20 @@ def evaluate_assign(assign, state, tools): var_names = assign.targets result = evaluate_ast(assign.value, state, tools) if len(var_names) == 1: - if isinstance(var_names[0], ast.Tuple): - for i, elem in enumerate(var_names[0].elts): + target = var_names[0] + if isinstance(target, ast.Tuple): + for i, elem in enumerate(target.elts): state[elem.id] = result[i] + elif isinstance(target, ast.Attribute): + obj = evaluate_ast(target.value, state, tools) + setattr(obj, target.attr, result) + elif isinstance(target, ast.Subscript): + obj = evaluate_ast(target.value, state, tools) + key = evaluate_ast(target.slice, state, tools) + obj[key] = result else: - state[var_names[0].id] = result + state[target.id] = result + else: if len(result) != len(var_names): raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.") @@ -190,41 +259,64 @@ def evaluate_assign(assign, state, tools): def evaluate_call(call, state, tools): + if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): + raise InterpretorError( + f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})." + ) if isinstance(call.func, ast.Attribute): obj = evaluate_ast(call.func.value, state, tools) func_name = call.func.attr if not hasattr(obj, func_name): raise InterpretorError(f"Object {obj} has no attribute {func_name}") func = getattr(obj, func_name) - args = [evaluate_ast(arg, state, tools) for arg in call.args] - kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} - return func(*args, **kwargs) - elif isinstance(call.func, ast.Name): func_name = call.func.id - if func_name in state: func = state[func_name] elif func_name in tools: func = tools[func_name] + elif func_name in ERRORS: + func = ERRORS[func_name] else: raise InterpretorError( f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})." ) - # Todo deal with args - args = [evaluate_ast(arg, state, tools) for arg in call.args] - kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} - output = func(*args, **kwargs) - # store logs of print statements - if func_name == "print": - state["print_outputs"] += output + "\n" + args = [evaluate_ast(arg, state, tools) for arg in call.args] + kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} - return output + if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes + # Instantiate the class using its constructor + obj = func.__new__(func) # Create a new instance of the class + if hasattr(obj, "__init__"): # Check if the class has an __init__ method + obj.__init__(*args, **kwargs) # Call the __init__ method correctly + return obj else: - raise InterpretorError( - f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})." 
- ) + if func_name == "super": + if not args: + if "__class__" in state and "self" in state: + return super(state["__class__"], state["self"]) + else: + raise InterpretorError("super() needs at least one argument") + cls = args[0] + if not isinstance(cls, type): + raise InterpretorError("super() argument 1 must be type") + if len(args) == 1: + return super(cls) + elif len(args) == 2: + instance = args[1] + return super(cls, instance) + else: + raise InterpretorError("super() takes at most 2 arguments") + + else: + if func_name == "print": + output = " ".join(map(str, args)) + state["print_outputs"] += output + "\n" + return output + else: # Assume it's a callable object + output = func(*args, **kwargs) + return output def evaluate_subscript(subscript, state, tools): @@ -248,6 +340,10 @@ def evaluate_subscript(subscript, state, tools): def evaluate_name(name, state, tools): if name.id in state: return state[name.id] + elif name.id in tools: + return tools[name.id] + elif name.id in ERRORS: + return ERRORS[name.id] close_matches = difflib.get_close_matches(name.id, list(state.keys())) if len(close_matches) > 0: return state[close_matches[0]] @@ -307,7 +403,11 @@ def evaluate_for(for_loop, state, tools): result = None iterator = evaluate_ast(for_loop.iter, state, tools) for counter in iterator: - state[for_loop.target.id] = counter + if isinstance(for_loop.target, ast.Tuple): + for i, elem in enumerate(for_loop.target.elts): + state[elem.id] = counter[i] + else: + state[for_loop.target.id] = counter for node in for_loop.body: try: line_result = evaluate_ast(node, state, tools) @@ -337,7 +437,56 @@ def evaluate_listcomp(listcomp, state, tools): return result -def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]): +def evaluate_try(try_node, state, tools): + try: + for stmt in try_node.body: + evaluate_ast(stmt, state, tools) + except Exception as e: + matched = False + for handler in try_node.handlers: + if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)): + matched = True + if handler.name: + state[handler.name] = e + for stmt in handler.body: + evaluate_ast(stmt, state, tools) + break + if not matched: + raise e + else: + if try_node.orelse: + for stmt in try_node.orelse: + evaluate_ast(stmt, state, tools) + finally: + if try_node.finalbody: + for stmt in try_node.finalbody: + evaluate_ast(stmt, state, tools) + + +def evaluate_raise(raise_node, state, tools): + if raise_node.exc is not None: + exc = evaluate_ast(raise_node.exc, state, tools) + else: + exc = None + if raise_node.cause is not None: + cause = evaluate_ast(raise_node.cause, state, tools) + else: + cause = None + if exc is not None: + if cause is not None: + raise exc from cause + else: + raise exc + else: + raise InterpretorError("Re-raise is not supported without an active exception") + + +def evaluate_ast( + expression: ast.AST, + state: Dict[str, Any], + tools: Dict[str, Callable], + authorized_imports: List[str] = LIST_SAFE_MODULES, +): """ Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given set of functions. @@ -353,6 +502,9 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca tools (`Dict[str, Callable]`): The functions that may be called during the evaluation. Any call to another function will fail with an `InterpretorError`. + authorized_imports (`List[str]`): + The list of modules that can be imported by the code. By default, only a few safe modules are allowed. 
+ Add more at your own risk! """ if isinstance(expression, ast.Assign): # Assignement -> we evaluate the assignement which should update the state @@ -459,7 +611,7 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca return result elif isinstance(expression, ast.Import): for alias in expression.names: - if alias.name in LIST_SAFE_MODULES: + if alias.name in authorized_imports: module = __import__(alias.name) state[alias.asname or alias.name] = module else: @@ -468,19 +620,27 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca elif isinstance(expression, ast.While): return evaluate_while(expression, state, tools) elif isinstance(expression, ast.ImportFrom): - if expression.module in LIST_SAFE_MODULES: + if expression.module in authorized_imports: module = __import__(expression.module) for alias in expression.names: state[alias.asname or alias.name] = getattr(module, alias.name) else: raise InterpretorError(f"Import from {expression.module} is not allowed.") return None + elif isinstance(expression, ast.ClassDef): + return evaluate_class_def(expression, state, tools) + elif isinstance(expression, ast.Try): + return evaluate_try(expression, state, tools) + elif isinstance(expression, ast.Raise): + return evaluate_raise(expression, state, tools) else: # For now we refuse anything else. Let's add things as we need them. raise InterpretorError(f"{expression.__class__.__name__} is not supported.") -def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, state=None): +def evaluate_python_code( + code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES +): """ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set of functions. 
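For illustration, the new `authorized_imports` parameter can be exercised directly. A sketch (the import path mirrors this file's location; extending the list is only needed for modules outside the safe list):

```python
from transformers.agents.python_interpreter import (
    LIST_SAFE_MODULES,
    evaluate_python_code,
)

state = {}
result = evaluate_python_code(
    "import math\nmath.sqrt(16)",
    tools={},
    state=state,
    # extend the default safe list rather than replacing it
    authorized_imports=LIST_SAFE_MODULES + ["requests"],
)
print(result)  # 4.0
```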
@@ -506,9 +666,10 @@ def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, s state = {} result = None state["print_outputs"] = "" + for idx, node in enumerate(expression.body): try: - line_result = evaluate_ast(node, state, tools) + line_result = evaluate_ast(node, state, tools, authorized_imports) except InterpretorError as e: msg = f"You tried to execute the following code:\n{code}\n" msg += f"You got these outputs:\n{state['print_outputs']}\n" diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 4016a20f81e441..36e88d5a06ffa8 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -185,7 +185,7 @@ def save(self, output_dir): "tool_class": full_name, "description": self.description, "name": self.name, - "inputs": str(self.inputs), + "inputs": self.inputs, "output_type": str(self.output_type), } with open(config_file, "w", encoding="utf-8") as f: @@ -315,7 +315,7 @@ def from_hub( if tool_class.output_type != custom_tool["output_type"]: tool_class.output_type = custom_tool["output_type"] - return tool_class(model_repo_id, token=token, **kwargs) + return tool_class(**kwargs) def push_to_hub( self, diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 7a3257494fb7a4..dbe6c90a9ea0f6 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -353,3 +353,131 @@ def test_print_output(self): result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) assert result == "Ok no one cares" assert state["print_outputs"] == "Hello world!\nOk no one cares\n" + + def test_tuple_target_in_iterator(self): + code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert result == "Samuel" + + def test_classes(self): + code = """ +class Animal: + species = "Generic Animal" + + def __init__(self, name, age): + self.name = name + self.age = age + + def sound(self): + return "The animal makes a sound." + + def __str__(self): + return f"{self.name}, {self.age} years old" + +class Dog(Animal): + species = "Canine" + + def __init__(self, name, age, breed): + super().__init__(name, age) + self.breed = breed + + def sound(self): + return "The dog barks." + + def __str__(self): + return f"{self.name}, {self.age} years old, {self.breed}" + +class Cat(Animal): + def sound(self): + return "The cat meows." + + def __str__(self): + return f"{self.name}, {self.age} years old, {self.species}" + + +# Testing multiple instances +dog1 = Dog("Fido", 3, "Labrador") +dog2 = Dog("Buddy", 5, "Golden Retriever") + +# Testing method with built-in function +animals = [dog1, dog2, Cat("Whiskers", 2)] +num_animals = len(animals) + +# Testing exceptions in methods +class ExceptionTest: + def method_that_raises(self): + raise ValueError("An error occurred") + +try: + exc_test = ExceptionTest() + exc_test.method_that_raises() +except ValueError as e: + exception_message = str(e) + + +# Collecting results +dog1_sound = dog1.sound() +dog1_str = str(dog1) +dog2_sound = dog2.sound() +dog2_str = str(dog2) +cat = Cat("Whiskers", 2) +cat_sound = cat.sound() +cat_str = str(cat) + """ + state = {} + evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) + + # Assert results + assert state["dog1_sound"] == "The dog barks." 
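+        # Dog overrides both sound() and __str__, so the subclass versions must win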
+ assert state["dog1_str"] == "Fido, 3 years old, Labrador" + assert state["dog2_sound"] == "The dog barks." + assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever" + assert state["cat_sound"] == "The cat meows." + assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal" + assert state["num_animals"] == 3 + assert state["exception_message"] == "An error occurred" + + def test_variable_args(self): + code = """ +def var_args_method(self, *args, **kwargs): + return sum(args) + sum(kwargs.values()) + +var_args_method(1, 2, 3, x=4, y=5) +""" + state = {} + result = evaluate_python_code(code, {"sum": sum}, state=state) + assert result == 15 + + def test_exceptions(self): + code = """ +def method_that_raises(self): + raise ValueError("An error occurred") + +try: + method_that_raises() +except ValueError as e: + exception_message = str(e) + """ + state = {} + evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) + assert state["exception_message"] == "An error occurred" + + def test_subscript(self): + code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" + + state = {} + evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) + assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} + + def test_print(self): + code = "print(min([1, 2, 3]))" + state = {} + result = evaluate_python_code(code, {"min": min, "print": print}, state=state) + assert result == "1" + assert state["print_outputs"] == "1\n" + + def test_types_as_objects(self): + code = "type_a = float(2); type_b = str; type_c = int" + state = {} + result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) + assert result == int