Skip to content

Commit

Permalink
Fix code quality issues
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Jun 28, 2024
1 parent 31583a7 commit 8efac17
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/transformers/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pathlib
import tempfile
import uuid

import numpy as np

from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
Expand Down
68 changes: 43 additions & 25 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_PLAN,
USER_PROMPT_PLAN,
SYSTEM_PROMPT_FACTS_UPDATE,
SYSTEM_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
PLAN_UPDATE_FINAL_PLAN_REDACTION
)
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
Expand Down Expand Up @@ -419,15 +419,21 @@ def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) ->
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
memory.append(thought_message)
if "facts" in step_log:
thought_message = {"role": MessageRole.ASSISTANT, "content": f"[FACTS LIST]:\n" + step_log["facts"].strip()}
thought_message = {
"role": MessageRole.ASSISTANT,
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
}
memory.append(thought_message)

if "plan" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": f"[PLAN]:\n" + step_log["plan"].strip()}
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
memory.append(thought_message)

if "tool_call" in step_log and summary_mode:
tool_call_message = {"role": MessageRole.ASSISTANT, "content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip()}
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
}
memory.append(tool_call_message)

if "error" in step_log or "observation" in step_log:
Expand Down Expand Up @@ -747,28 +753,35 @@ def direct_run(self, task: str, **kwargs):

return final_answer


def planning_step(self, task, is_first_step=False, iteration: int = None):
"""
Plan the next steps to reach the objective.
"""
if is_first_step:
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
message_prompt_task = {"role": MessageRole.USER, "content": f"""Here is the task:
message_prompt_task = {
"role": MessageRole.USER,
"content": f"""Here is the task:
```
{task}
```
Now begin!"""}
Now begin!""",
}

answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])

message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
message_user_prompt_plan = {"role": MessageRole.USER, "content": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
answer_facts=answer_facts
)}
answer_plan = self.llm_engine([message_system_prompt_plan, message_user_prompt_plan], stop_sequences=['<end_plan>'])
message_user_prompt_plan = {
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
answer_facts=answer_facts,
),
}
answer_plan = self.llm_engine(
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
)

final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
Expand All @@ -782,8 +795,10 @@ def planning_step(self, task, is_first_step=False, iteration: int = None):
self.logger.debug("===== Initial plan: =====")
self.logger.debug(final_plan_redaction)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(summary_mode=False) # This will not log the plan but will log facts

agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
) # This will not log the plan but will log facts

# Redact updated facts
facts_update_system_prompt = {
"role": MessageRole.SYSTEM,
Expand All @@ -806,19 +821,20 @@ def planning_step(self, task, is_first_step=False, iteration: int = None):
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
facts_update=facts_update,
remaining_steps = (self.max_iterations - iteration)
)}
plan_update = self.llm_engine([plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=['<end_plan>'])

remaining_steps=(self.max_iterations - iteration),
),
}
plan_update = self.llm_engine(
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
)

# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
self.logs.append(
{"plan": final_plan_redaction, "facts": final_facts_redaction}
)
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Updated plan: =====")
self.logger.debug(final_plan_redaction)
print("UPDATED PLAN:", final_plan_redaction)
Expand Down Expand Up @@ -893,7 +909,9 @@ def step(self):
if isinstance(arguments, dict):
if "answer" in arguments:
answer = arguments["answer"]
if isinstance(answer, str) and answer in self.state.keys(): # if the answer is a state variable, return the value
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
else:
answer = arguments
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/agents/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
"""

SYSTEM_PROMPT_FACTS = """Below I will present you a task.
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
Expand All @@ -390,7 +390,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""

Expand Down Expand Up @@ -472,4 +472,4 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
Here is my new/updated plan of action to solve the task:
```
{plan_update}
```"""
```"""
44 changes: 29 additions & 15 deletions src/transformers/agents/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@
import ast
import builtins
import difflib
import numpy as np
import pandas as pd
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional

import numpy as np

from ..utils import is_pandas_available


if is_pandas_available():
import pandas as pd


class InterpreterError(ValueError):
"""
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
Expand Down Expand Up @@ -54,6 +61,7 @@ class InterpreterError(ValueError):
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000


class BreakException(Exception):
pass

Expand Down Expand Up @@ -335,7 +343,7 @@ def set_value(target, value, state, tools):
state[target.id] = value
elif isinstance(target, ast.Tuple):
if not isinstance(value, tuple):
if hasattr(value, '__iter__') and not isinstance(value, (str, bytes)):
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
value = tuple(value)
else:
raise InterpreterError("Cannot unpack non-tuple value")
Expand All @@ -354,16 +362,14 @@ def set_value(target, value, state, tools):

def evaluate_call(call, state, tools):
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
raise InterpreterError(
f"This is not a correct function: {call.func})."
)
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, tools)
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name)

elif isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in state:
Expand All @@ -388,13 +394,12 @@ def evaluate_call(call, state, tools):
for arg in call.args:
if isinstance(arg, ast.Starred):
unpacked = evaluate_ast(arg.value, state, tools)
if not hasattr(unpacked, '__iter__') or isinstance(unpacked, (str, bytes)):
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
args.extend(unpacked)
else:
args.append(evaluate_ast(arg, state, tools))


kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}

if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
Expand Down Expand Up @@ -472,6 +477,7 @@ def evaluate_name(name, state, tools):
return state[close_matches[0]]
raise InterpreterError(f"The variable `{name.id}` is not defined.")


def evaluate_condition(condition, state, tools):
left = evaluate_ast(condition.left, state, tools)
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
Expand Down Expand Up @@ -651,8 +657,8 @@ def evaluate_with(with_node, state, tools):

def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name):
module_path = module_name.split('.')
module_subpaths = ['.'.join(module_path[:i]) for i in range(1, len(module_path) + 1)]
module_path = module_name.split(".")
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)

if isinstance(expression, ast.Import):
Expand All @@ -661,7 +667,9 @@ def check_module_authorized(module_name):
module = __import__(alias.name)
state[alias.asname or alias.name] = module
else:
raise InterpreterError(f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}")
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
)
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
Expand All @@ -672,6 +680,7 @@ def check_module_authorized(module_name):
raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None


def evaluate_dictcomp(dictcomp, state, tools):
result = {}
for gen in dictcomp.generators:
Expand Down Expand Up @@ -713,7 +722,9 @@ def evaluate_ast(
"""
global OPERATIONS_COUNT
if OPERATIONS_COUNT >= MAX_OPERATIONS:
raise InterpreterError(f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations.")
raise InterpreterError(
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
)
OPERATIONS_COUNT += 1
if isinstance(expression, ast.Assign):
# Assignement -> we evaluate the assignement which should update the state
Expand Down Expand Up @@ -857,7 +868,7 @@ def evaluate_python_code(
try:
result = evaluate_ast(node, state, tools, authorized_imports)
except InterpreterError as e:
msg=""
msg = ""
if len(PRINT_OUTPUTS) > 0:
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
Expand All @@ -869,6 +880,9 @@ def evaluate_python_code(
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
state["print_outputs"] = PRINT_OUTPUTS
else:
state["print_outputs"] = PRINT_OUTPUTS[:MAX_LEN_OUTPUT] + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
state["print_outputs"] = (
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
)

return result
24 changes: 12 additions & 12 deletions tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import BASE_PYTHON_TOOLS, LIST_SAFE_MODULES
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code

from .test_tools_common import ToolTesterMixin
Expand Down Expand Up @@ -274,8 +274,9 @@ def calculate_isbn_10_check_digit(number):
print(check_digits)
"""
state = {}
evaluate_python_code(code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state)

evaluate_python_code(
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
)

def test_listcomp(self):
code = "x = [i for i in range(3)]"
Expand Down Expand Up @@ -306,16 +307,16 @@ def test_dictcomp(self):
result = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4}

code= "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert result == {102: 'b'}
assert result == {102: "b"}

code = """
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
"""
result = evaluate_python_code(code, {}, state={})
assert result == {'A': ('a', 'b'), 'B': ('a', 'b')}
assert result == {"A": ("a", "b"), "B": ("a", "b")}

def test_tuple_assignment(self):
code = "a, b = 0, 1\nb"
Expand Down Expand Up @@ -420,10 +421,9 @@ def test_additional_imports(self):
code = "import matplotlib.pyplot as plt"
evaluate_python_code(code, authorized_imports=["matplotlib.pyplot"], state={})
evaluate_python_code(code, authorized_imports=["matplotlib"], state={})
with pytest.raises(InterpreterError) as e:
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["pyplot"], state={})


def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
Expand Down Expand Up @@ -729,7 +729,7 @@ def returns_none(a):
assert result is None

def test_nested_for_loop(self):
code="""
code = """
all_res = []
for i in range(10):
subres = []
Expand All @@ -742,7 +742,7 @@ def test_nested_for_loop(self):
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]

def test_pandas(self):
code = """
Expand Down Expand Up @@ -794,7 +794,7 @@ def haversine(lat1, lon1, lat2, lon2):
assert round(result, 1) == 622395.4

def test_for(self):
code="""
code = """
shifts = {
"Worker A": ("6:45 pm", "8:00 pm"),
"Worker B": ("10:00 am", "11:45 am")
Expand All @@ -806,4 +806,4 @@ def test_for(self):
shift_intervals
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {'Worker A': '8:00 pm', 'Worker B': '11:45 am'}
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}

0 comments on commit 8efac17

Please sign in to comment.