Skip to content

Commit

Permalink
Fix interpreter errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Jul 9, 2024
1 parent 50cad81 commit 9285705
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,6 @@ def __init__(
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
print(self.authorized_imports)
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.custom_tools = {}

Expand Down
5 changes: 4 additions & 1 deletion src/transformers/agents/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import builtins
import difflib
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional

import numpy as np
Expand Down Expand Up @@ -448,6 +449,8 @@ def evaluate_subscript(subscript, state, tools, custom_tools):
return parent_object.loc[index]
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
return value[index]
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
return value[index]
elif isinstance(index, slice):
return value[index]
elif isinstance(value, (list, tuple)):
Expand Down Expand Up @@ -666,7 +669,7 @@ def check_module_authorized(module_name):
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
module = __import__(alias.name)
module = import_module(alias.name)
state[alias.asname or alias.name] = module
else:
raise InterpreterError(
Expand Down
18 changes: 18 additions & 0 deletions tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,13 @@ def test_imports(self):
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A"

# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])

code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])

def test_additional_imports(self):
code = "import numpy as np"
evaluate_python_code(code, authorized_imports=["numpy"], state={})
Expand Down Expand Up @@ -771,6 +778,17 @@ def test_pandas(self):
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert np.array_equal(result.values[0], [104, 1])

code = """import pandas as pd
data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1},
{"Pclass": 2, "Survived": 0},
{"Pclass": 2, "Survived": 1}
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
assert result.values[1] == 0.5

def test_starred(self):
code = """
from math import radians, sin, cos, sqrt, atan2
Expand Down

0 comments on commit 9285705

Please sign in to comment.