From ad87b0d066544b1415dab300d05e3464ea2500d9 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Thu, 23 May 2024 20:35:31 +0200 Subject: [PATCH] Fix tuples in for loop in python interpreter --- src/transformers/agents/agents.py | 2 +- src/transformers/agents/prompts.py | 11 ++++++----- src/transformers/agents/python_interpreter.py | 8 ++++++-- tests/agents/test_python_interpreter.py | 5 +++++ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 2e4042804e2d32..81e473d08669de 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -617,7 +617,7 @@ def provide_final_answer(self, task) -> str: } ] try: - return self.llm_engine(self.prompt, stop_sequences=["", "Observation:"]) + return self.llm_engine(self.prompt) except Exception as e: return f"Error in generating final llm output: {e}." diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 829b6dcf0723e4..a3e7873e4c8234 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -350,15 +350,16 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): <> -You also can perform computations in the python code you generate. +You also can perform computations in the python code that you generate. Here are the rules you should always follow to solve your task: 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```' sequence, else you will fail. -2. Make sure the variable you use are all defined. +2. Use only variables that you defined! 3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. -4. Do not perform too many operations in a single code block. Split the task into intermediate code blocks. Then use print() to save the intermediate result. Finally, use final_answer() to return the final result. -5. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. -6. Never re-do a tool call that you previously did with the exact same parameters. +4. In the code you generate, you can do several tool calls in parallel, but do not perform too many tool calls sequentially in a single block, especialy when the output of one tool call is in an unpredictable format: in that case rather split the task into intermediate code blocks, then use print() to save the intermediate result. Finally, use final_answer() to return the final result. +For instance if you need to call search for three independant topics, you can do three calls in the same code block. But if you're waiting for the ouput of one search call to generate an image, rather print the search output, then generate the image in the next step. +5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. +6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 8ca1cd182095d8..b64f8760208ccf 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -29,7 +29,7 @@ class InterpretorError(ValueError): pass -LIST_SAFE_MODULES = ["random", "math", "time", "queue", "itertools", "re", "stat", "statistics", "unicodedata"] +LIST_SAFE_MODULES = ["random", "collections", "requests", "math", "time", "queue", "itertools", "re", "stat", "statistics", "unicodedata"] class BreakException(Exception): @@ -307,7 +307,11 @@ def evaluate_for(for_loop, state, tools): result = None iterator = evaluate_ast(for_loop.iter, state, tools) for counter in iterator: - state[for_loop.target.id] = counter + if isinstance(for_loop.target, ast.Tuple): + for i, elem in enumerate(for_loop.target.elts): + state[elem.id] = counter[i] + else: + state[for_loop.target.id] = counter for node in for_loop.body: try: line_result = evaluate_ast(node, state, tools) diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 7a3257494fb7a4..3e02133476a553 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -353,3 +353,8 @@ def test_print_output(self): result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) assert result == "Ok no one cares" assert state["print_outputs"] == "Hello world!\nOk no one cares\n" + + def test_tuple_target_in_iterator(self): + code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert result == "Samuel"