diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 9af07b49592ba0..a147c9d30ad1ff 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -26,7 +26,7 @@ from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfEngine, MessageRole from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT -from .python_interpreter import evaluate_python_code, LIST_SAFE_MODULES +from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .tools import ( DEFAULT_TOOL_DESCRIPTION_TEMPLATE, Tool, @@ -411,10 +411,7 @@ def write_inner_memory_from_logs(self) -> List[Dict[str, str]]: return memory def get_succinct_logs(self): - return [ - {key: value for key, value in log.items() if key != "agent_memory"} - for log in self.logs - ] + return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs] def extract_action(self, llm_output: str, split_token: str) -> str: """ @@ -576,7 +573,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): code_action, available_tools, state=self.state, - authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports + authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, ) self.logger.info(self.state["print_outputs"]) return output @@ -887,7 +884,7 @@ def step(self): code_action, available_tools, state=self.state, - authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports + authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, ) information = self.state["print_outputs"] self.logger.warning("Print outputs:") diff --git a/src/transformers/agents/llm_engine.py b/src/transformers/agents/llm_engine.py index 3372bcd936ffe4..76458b02677dbb 100644 --- a/src/transformers/agents/llm_engine.py +++ b/src/transformers/agents/llm_engine.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from copy import deepcopy from enum import Enum from typing import Dict, List diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 5b8f5f37ce7214..cab31db16a432b 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import ast +import builtins import difflib from collections.abc import Mapping from typing import Any, Callable, Dict, List, Optional @@ -29,6 +30,13 @@ class InterpretorError(ValueError): pass +ERRORS = { + name: getattr(builtins, name) + for name in dir(builtins) + if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException) +} + + LIST_SAFE_MODULES = [ "random", "collections", @@ -99,13 +107,27 @@ def evaluate_while(while_loop, state, tools): def create_function(func_def, state, tools): - def new_func(*args): - new_state = state.copy() - for arg, val in zip(func_def.args.args, args): - new_state[arg.arg] = val + def new_func(*args, **kwargs): + func_state = state.copy() + arg_names = [arg.arg for arg in func_def.args.args] + for name, value in zip(arg_names, args): + func_state[name] = value + if func_def.args.vararg: + vararg_name = func_def.args.vararg.arg + func_state[vararg_name] = args + if func_def.args.kwarg: + kwarg_name = func_def.args.kwarg.arg + func_state[kwarg_name] = kwargs + + # Update function state with self and __class__ + if func_def.args.args and func_def.args.args[0].arg == "self": + if args: + func_state["self"] = args[0] + func_state["__class__"] = args[0].__class__ + result = None - for node in func_def.body: - result = evaluate_ast(node, new_state, tools) + for stmt in func_def.body: + result = evaluate_ast(stmt, func_state, tools) return result return new_func @@ -119,30 +141,6 @@ def create_class(class_name, class_bases, class_body): def evaluate_function_def(func_def, state, tools): - def create_function(func_def, state, tools): - def func(*args, **kwargs): - func_state = state.copy() - arg_names = [arg.arg for arg in func_def.args.args] - for name, value in zip(arg_names, args): - func_state[name] = value - if func_def.args.vararg: - vararg_name = func_def.args.vararg.arg - func_state[vararg_name] = args - if func_def.args.kwarg: - kwarg_name = func_def.args.kwarg.arg - func_state[kwarg_name] = kwargs - - # Update function state with self and __class__ - if func_def.args.args and func_def.args.args[0].arg == 'self': - if args: - func_state['self'] = args[0] - func_state['__class__'] = args[0].__class__ - - result = None - for stmt in func_def.body: - result = evaluate_ast(stmt, func_state, tools) - return result - return func tools[func_def.name] = create_function(func_def, state, tools) return tools[func_def.name] @@ -166,7 +164,6 @@ def evaluate_class_def(class_def, state, tools): return new_class - def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): # Extract the target variable name and the operation if isinstance(expression.target, ast.Name): @@ -246,8 +243,13 @@ def evaluate_assign(assign, state, tools): elif isinstance(target, ast.Attribute): obj = evaluate_ast(target.value, state, tools) setattr(obj, target.attr, result) + elif isinstance(target, ast.Subscript): + obj = evaluate_ast(target.value, state, tools) + key = evaluate_ast(target.slice, state, tools) + obj[key] = result else: state[target.id] = result + else: if len(result) != len(var_names): raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.") @@ -256,7 +258,6 @@ def evaluate_assign(assign, state, tools): return result - def evaluate_call(call, state, tools): if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): raise InterpretorError( @@ -280,21 +281,21 @@ def evaluate_call(call, state, tools): raise InterpretorError( f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})." ) - + args = [evaluate_ast(arg, state, tools) for arg in call.args] kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} - if isinstance(func, type) and len(func.__module__.split('.')) > 1: # Check for user-defined classes + if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes # Instantiate the class using its constructor obj = func.__new__(func) # Create a new instance of the class - if hasattr(obj, '__init__'): # Check if the class has an __init__ method + if hasattr(obj, "__init__"): # Check if the class has an __init__ method obj.__init__(*args, **kwargs) # Call the __init__ method correctly return obj else: if func_name == "super": if not args: - if '__class__' in state and 'self' in state: - return super(state['__class__'], state['self']) + if "__class__" in state and "self" in state: + return super(state["__class__"], state["self"]) else: raise InterpretorError("super() needs at least one argument") cls = args[0] @@ -307,17 +308,15 @@ def evaluate_call(call, state, tools): return super(cls, instance) else: raise InterpretorError("super() takes at most 2 arguments") - - else: - # Assume it's a callable object - output = func(*args, **kwargs) - # Store logs of print statements + else: if func_name == "print": + output = " ".join(map(str, args)) state["print_outputs"] += output + "\n" - - - return output + return output + else: # Assume it's a callable object + output = func(*args, **kwargs) + return output def evaluate_subscript(subscript, state, tools): @@ -337,8 +336,6 @@ def evaluate_subscript(subscript, state, tools): return value[close_matches[0]] raise InterpretorError(f"Could not index {value} with '{index}'.") -import builtins -ERRORS = {name: getattr(builtins, name) for name in dir(builtins) if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)} def evaluate_name(name, state, tools): if name.id in state: @@ -464,7 +461,6 @@ def evaluate_try(try_node, state, tools): evaluate_ast(stmt, state, tools) - def evaluate_raise(raise_node, state, tools): if raise_node.exc is not None: exc = evaluate_ast(raise_node.exc, state, tools) diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index e5120e84988cd4..441ab721c49c5b 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -359,39 +359,38 @@ def test_tuple_target_in_iterator(self): result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result == "Samuel" - def test_classes(self): code = """ class Animal: species = "Generic Animal" - + def __init__(self, name, age): self.name = name self.age = age - + def sound(self): return "The animal makes a sound." - + def __str__(self): return f"{self.name}, {self.age} years old" class Dog(Animal): species = "Canine" - + def __init__(self, name, age, breed): super().__init__(name, age) self.breed = breed - + def sound(self): return "The dog barks." - + def __str__(self): return f"{self.name}, {self.age} years old, {self.breed}" class Cat(Animal): def sound(self): return "The cat meows." - + def __str__(self): return f"{self.name}, {self.age} years old, {self.species}" @@ -427,16 +426,16 @@ def method_that_raises(self): """ state = {} evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) - + # Assert results - assert state['dog1_sound'] == "The dog barks." - assert state['dog1_str'] == "Fido, 3 years old, Labrador" - assert state['dog2_sound'] == "The dog barks." - assert state['dog2_str'] == "Buddy, 5 years old, Golden Retriever" - assert state['cat_sound'] == "The cat meows." - assert state['cat_str'] == "Whiskers, 2 years old, Generic Animal" - assert state['num_animals'] == 3 - assert state['exception_message'] == "An error occurred" + assert state["dog1_sound"] == "The dog barks." + assert state["dog1_str"] == "Fido, 3 years old, Labrador" + assert state["dog2_sound"] == "The dog barks." + assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever" + assert state["cat_sound"] == "The cat meows." + assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal" + assert state["num_animals"] == 3 + assert state["exception_message"] == "An error occurred" def test_variable_args(self): code = """ @@ -461,4 +460,18 @@ def method_that_raises(self): """ state = {} evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) - assert state['exception_message'] == "An error occurred" + assert state["exception_message"] == "An error occurred" + + def test_subscript(self): + code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" + + state = {} + evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) + assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} + + def test_print(self): + code = "print(min([1, 2, 3]))" + state = {} + result = evaluate_python_code(code, {"min": min, "print": print}, state=state) + assert result == "1" + assert state["print_outputs"] == "1\n"