Skip to content

Commit

Permalink
Fix code quality issues
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Jun 28, 2024
1 parent 31583a7 commit 1dce68f
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/transformers/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pathlib
import tempfile
import uuid

import numpy as np

from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
Expand Down
14 changes: 7 additions & 7 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_PLAN,
USER_PROMPT_PLAN,
SYSTEM_PROMPT_FACTS_UPDATE,
SYSTEM_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
PLAN_UPDATE_FINAL_PLAN_REDACTION
)
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
Expand Down Expand Up @@ -419,11 +419,11 @@ def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) ->
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
memory.append(thought_message)
if "facts" in step_log:
thought_message = {"role": MessageRole.ASSISTANT, "content": f"[FACTS LIST]:\n" + step_log["facts"].strip()}
thought_message = {"role": MessageRole.ASSISTANT, "content": "[FACTS LIST]:\n" + step_log["facts"].strip()}
memory.append(thought_message)

if "plan" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": f"[PLAN]:\n" + step_log["plan"].strip()}
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
memory.append(thought_message)

if "tool_call" in step_log and summary_mode:
Expand Down Expand Up @@ -783,7 +783,7 @@ def planning_step(self, task, is_first_step=False, iteration: int = None):
self.logger.debug(final_plan_redaction)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(summary_mode=False) # This will not log the plan but will log facts

# Redact updated facts
facts_update_system_prompt = {
"role": MessageRole.SYSTEM,
Expand All @@ -809,7 +809,7 @@ def planning_step(self, task, is_first_step=False, iteration: int = None):
remaining_steps = (self.max_iterations - iteration)
)}
plan_update = self.llm_engine([plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=['<end_plan>'])

# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
final_facts_redaction = f"""Here is the updated list of the facts that I know:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/agents/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
"""

SYSTEM_PROMPT_FACTS = """Below I will present you a task.
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
Expand All @@ -390,7 +390,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer. 
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""

Expand Down Expand Up @@ -472,4 +472,4 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
Here is my new/updated plan of action to solve the task:
```
{plan_update}
```"""
```"""
8 changes: 5 additions & 3 deletions src/transformers/agents/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import ast
import builtins
import difflib
import numpy as np
import pandas as pd
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import pandas as pd


class InterpreterError(ValueError):
"""
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
Expand Down Expand Up @@ -363,7 +365,7 @@ def evaluate_call(call, state, tools):
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name)

elif isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in state:
Expand Down
4 changes: 2 additions & 2 deletions tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import BASE_PYTHON_TOOLS, LIST_SAFE_MODULES
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code

from .test_tools_common import ToolTesterMixin
Expand Down Expand Up @@ -420,7 +420,7 @@ def test_additional_imports(self):
code = "import matplotlib.pyplot as plt"
evaluate_python_code(code, authorized_imports=["matplotlib.pyplot"], state={})
evaluate_python_code(code, authorized_imports=["matplotlib"], state={})
with pytest.raises(InterpreterError) as e:
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["pyplot"], state={})


Expand Down

0 comments on commit 1dce68f

Please sign in to comment.