Skip to content

Commit

Permalink
Interpreter tweaks: tuples and listcomp
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Apr 29, 2024
1 parent 86bae32 commit a33e3b5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ def __init__(
else self.default_tool_description_template,
**kwargs,
)
self.python_evaluator = evaluate_python_code

def step(self):
self.agent_memory = self.write_inner_memory_from_logs()
Expand Down Expand Up @@ -690,12 +691,12 @@ def step(self):
self.log.warn(f"====Agent is executing the code below:\n{code_action}\n====")
try:
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
result = evaluate_python_code(code_action, available_tools, state=self.state)
result = self.python_evaluator(code_action, available_tools, state=self.state)
information = self.state['print_outputs']
self.log.info(information)
self.logs[-1]["observation"] = information
except Exception as e:
error_msg = f"Executing code:\n{code_action}\nYielded error:\n{str(e)}\nMake sure to provide correct code."
error_msg = f"Failed while trying to execute the code below:\n{code_action}\Failed due to the following error:\n{str(e)}\nMake sure to provide correct code."
raise AgentExecutionError(error_msg)
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/agents/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
elif isinstance(expression, ast.ListComp):
return evaluate_listcomp(expression, state, tools)
elif isinstance(expression, ast.UnaryOp):
operand = evaluate_ast(expression.operand, state, tools)
if isinstance(expression.op, ast.USub):
Expand Down Expand Up @@ -412,3 +416,17 @@ def evaluate_for(for_loop, state, tools):
if line_result is not None:
result = line_result
return result


def evaluate_listcomp(listcomp, state, tools):
result = []
vars = {}
for generator in listcomp.generators:
var_name = generator.target.id
iter_value = evaluate_ast(generator.iter, state, tools)
for value in iter_value:
vars[var_name] = value
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
elem = evaluate_ast(listcomp.elt, {**state, **vars}, tools)
result.append(elem)
return result
2 changes: 1 addition & 1 deletion tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_fake_react_code_agent(self):
assert output == '7.2904'
assert agent.logs[0]['task'] == "What is 2 multiplied by 3.6452?"
assert agent.logs[1]['observation'] == '\n12.511648652635412'
assert agent.logs[2]['tool_call'] == {'tool_arguments': 'final_answer(7.2904)\n', 'tool_name': 'code interpreter'}
assert agent.logs[2]['tool_call'] == {'tool_arguments': 'final_answer(7.2904)', 'tool_name': 'code interpreter'}

def test_fake_code_agent(self):
agent = CodeAgent(tools=[CalculatorTool()], llm_engine=fake_code_llm_oneshot)
Expand Down
12 changes: 11 additions & 1 deletion tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,14 @@ def test_imports(self):

code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
result = evaluate_python_code(code, {}, state={})
assert result == "lose"
assert result == "lose"

def test_tuples(self):
code = "x = (1, 2, 3)\nx[1]"
result = evaluate_python_code(code, {}, state={})
assert result == 2

def test_listcomp(self):
code = "x = [i for i in range(3)]"
result = evaluate_python_code(code, {"range":range}, state={})
assert result == [0, 1, 2]

0 comments on commit a33e3b5

Please sign in to comment.