Skip to content

Commit

Permalink
Add a few edge cases in interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed May 30, 2024
1 parent c2b5af6 commit c6772f8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 74 deletions.
11 changes: 4 additions & 7 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfEngine, MessageRole
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
from .python_interpreter import evaluate_python_code, LIST_SAFE_MODULES
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool,
Expand Down Expand Up @@ -411,10 +411,7 @@ def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
return memory

def get_succinct_logs(self):
return [
{key: value for key, value in log.items() if key != "agent_memory"}
for log in self.logs
]
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]

def extract_action(self, llm_output: str, split_token: str) -> str:
"""
Expand Down Expand Up @@ -576,7 +573,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs):
code_action,
available_tools,
state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
)
self.logger.info(self.state["print_outputs"])
return output
Expand Down Expand Up @@ -887,7 +884,7 @@ def step(self):
code_action,
available_tools,
state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
)
information = self.state["print_outputs"]
self.logger.warning("Print outputs:")
Expand Down
1 change: 0 additions & 1 deletion src/transformers/agents/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from enum import Enum
from typing import Dict, List
Expand Down
92 changes: 44 additions & 48 deletions src/transformers/agents/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import builtins
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional
Expand All @@ -29,6 +30,13 @@ class InterpretorError(ValueError):
pass


ERRORS = {
name: getattr(builtins, name)
for name in dir(builtins)
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
}


LIST_SAFE_MODULES = [
"random",
"collections",
Expand Down Expand Up @@ -99,13 +107,27 @@ def evaluate_while(while_loop, state, tools):


def create_function(func_def, state, tools):
def new_func(*args):
new_state = state.copy()
for arg, val in zip(func_def.args.args, args):
new_state[arg.arg] = val
def new_func(*args, **kwargs):
func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args]
for name, value in zip(arg_names, args):
func_state[name] = value
if func_def.args.vararg:
vararg_name = func_def.args.vararg.arg
func_state[vararg_name] = args
if func_def.args.kwarg:
kwarg_name = func_def.args.kwarg.arg
func_state[kwarg_name] = kwargs

# Update function state with self and __class__
if func_def.args.args and func_def.args.args[0].arg == "self":
if args:
func_state["self"] = args[0]
func_state["__class__"] = args[0].__class__

result = None
for node in func_def.body:
result = evaluate_ast(node, new_state, tools)
for stmt in func_def.body:
result = evaluate_ast(stmt, func_state, tools)
return result

return new_func
Expand All @@ -119,30 +141,6 @@ def create_class(class_name, class_bases, class_body):


def evaluate_function_def(func_def, state, tools):
def create_function(func_def, state, tools):
def func(*args, **kwargs):
func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args]
for name, value in zip(arg_names, args):
func_state[name] = value
if func_def.args.vararg:
vararg_name = func_def.args.vararg.arg
func_state[vararg_name] = args
if func_def.args.kwarg:
kwarg_name = func_def.args.kwarg.arg
func_state[kwarg_name] = kwargs

# Update function state with self and __class__
if func_def.args.args and func_def.args.args[0].arg == 'self':
if args:
func_state['self'] = args[0]
func_state['__class__'] = args[0].__class__

result = None
for stmt in func_def.body:
result = evaluate_ast(stmt, func_state, tools)
return result
return func
tools[func_def.name] = create_function(func_def, state, tools)
return tools[func_def.name]

Expand All @@ -166,7 +164,6 @@ def evaluate_class_def(class_def, state, tools):
return new_class



def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
# Extract the target variable name and the operation
if isinstance(expression.target, ast.Name):
Expand Down Expand Up @@ -246,8 +243,13 @@ def evaluate_assign(assign, state, tools):
elif isinstance(target, ast.Attribute):
obj = evaluate_ast(target.value, state, tools)
setattr(obj, target.attr, result)
elif isinstance(target, ast.Subscript):
obj = evaluate_ast(target.value, state, tools)
key = evaluate_ast(target.slice, state, tools)
obj[key] = result
else:
state[target.id] = result

else:
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
Expand All @@ -256,7 +258,6 @@ def evaluate_assign(assign, state, tools):
return result



def evaluate_call(call, state, tools):
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
raise InterpretorError(
Expand All @@ -280,21 +281,21 @@ def evaluate_call(call, state, tools):
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
)

args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}

if isinstance(func, type) and len(func.__module__.split('.')) > 1: # Check for user-defined classes
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
# Instantiate the class using its constructor
obj = func.__new__(func) # Create a new instance of the class
if hasattr(obj, '__init__'): # Check if the class has an __init__ method
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
obj.__init__(*args, **kwargs) # Call the __init__ method correctly
return obj
else:
if func_name == "super":
if not args:
if '__class__' in state and 'self' in state:
return super(state['__class__'], state['self'])
if "__class__" in state and "self" in state:
return super(state["__class__"], state["self"])
else:
raise InterpretorError("super() needs at least one argument")
cls = args[0]
Expand All @@ -307,17 +308,15 @@ def evaluate_call(call, state, tools):
return super(cls, instance)
else:
raise InterpretorError("super() takes at most 2 arguments")

else:
# Assume it's a callable object
output = func(*args, **kwargs)

# Store logs of print statements
else:
if func_name == "print":
output = " ".join(map(str, args))
state["print_outputs"] += output + "\n"


return output
return output
else: # Assume it's a callable object
output = func(*args, **kwargs)
return output


def evaluate_subscript(subscript, state, tools):
Expand All @@ -337,8 +336,6 @@ def evaluate_subscript(subscript, state, tools):
return value[close_matches[0]]
raise InterpretorError(f"Could not index {value} with '{index}'.")

import builtins
ERRORS = {name: getattr(builtins, name) for name in dir(builtins) if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)}

def evaluate_name(name, state, tools):
if name.id in state:
Expand Down Expand Up @@ -464,7 +461,6 @@ def evaluate_try(try_node, state, tools):
evaluate_ast(stmt, state, tools)



def evaluate_raise(raise_node, state, tools):
if raise_node.exc is not None:
exc = evaluate_ast(raise_node.exc, state, tools)
Expand Down
49 changes: 31 additions & 18 deletions tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,39 +359,38 @@ def test_tuple_target_in_iterator(self):
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "Samuel"


def test_classes(self):
code = """
class Animal:
species = "Generic Animal"
def __init__(self, name, age):
self.name = name
self.age = age
def sound(self):
return "The animal makes a sound."
def __str__(self):
return f"{self.name}, {self.age} years old"
class Dog(Animal):
species = "Canine"
def __init__(self, name, age, breed):
super().__init__(name, age)
self.breed = breed
def sound(self):
return "The dog barks."
def __str__(self):
return f"{self.name}, {self.age} years old, {self.breed}"
class Cat(Animal):
def sound(self):
return "The cat meows."
def __str__(self):
return f"{self.name}, {self.age} years old, {self.species}"
Expand Down Expand Up @@ -427,16 +426,16 @@ def method_that_raises(self):
"""
state = {}
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)

# Assert results
assert state['dog1_sound'] == "The dog barks."
assert state['dog1_str'] == "Fido, 3 years old, Labrador"
assert state['dog2_sound'] == "The dog barks."
assert state['dog2_str'] == "Buddy, 5 years old, Golden Retriever"
assert state['cat_sound'] == "The cat meows."
assert state['cat_str'] == "Whiskers, 2 years old, Generic Animal"
assert state['num_animals'] == 3
assert state['exception_message'] == "An error occurred"
assert state["dog1_sound"] == "The dog barks."
assert state["dog1_str"] == "Fido, 3 years old, Labrador"
assert state["dog2_sound"] == "The dog barks."
assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever"
assert state["cat_sound"] == "The cat meows."
assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal"
assert state["num_animals"] == 3
assert state["exception_message"] == "An error occurred"

def test_variable_args(self):
code = """
Expand All @@ -461,4 +460,18 @@ def method_that_raises(self):
"""
state = {}
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
assert state['exception_message'] == "An error occurred"
assert state["exception_message"] == "An error occurred"

def test_subscript(self):
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"

state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}

def test_print(self):
code = "print(min([1, 2, 3]))"
state = {}
result = evaluate_python_code(code, {"min": min, "print": print}, state=state)
assert result == "1"
assert state["print_outputs"] == "1\n"

0 comments on commit c6772f8

Please sign in to comment.