diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 328f9efb4d106e..8afec2ab7d2349 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -977,7 +977,6 @@ def __init__( self.python_evaluator = evaluate_python_code self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) - print(self.authorized_imports) self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) self.custom_tools = {} diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 636e37bad7ecb0..bcef817436ff6b 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -18,6 +18,7 @@ import builtins import difflib from collections.abc import Mapping +from importlib import import_module from typing import Any, Callable, Dict, List, Optional import numpy as np @@ -448,6 +449,8 @@ def evaluate_subscript(subscript, state, tools, custom_tools): return parent_object.loc[index] if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)): return value[index] + elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy): + return value[index] elif isinstance(index, slice): return value[index] elif isinstance(value, (list, tuple)): @@ -666,7 +669,7 @@ def check_module_authorized(module_name): if isinstance(expression, ast.Import): for alias in expression.names: if check_module_authorized(alias.name): - module = __import__(alias.name) + module = import_module(alias.name) state[alias.asname or alias.name] = module else: raise InterpreterError( diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 680c9a696ef6e5..a45f661c6e8d1c 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -414,6 +414,13 @@ def test_imports(self): result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result == "LATIN CAPITAL LETTER A" + # Test submodules are handled properly, thus not raising error + code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) + + code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) + def test_additional_imports(self): code = "import numpy as np" evaluate_python_code(code, authorized_imports=["numpy"], state={}) @@ -771,6 +778,17 @@ def test_pandas(self): result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) assert np.array_equal(result.values[0], [104, 1]) + code = """import pandas as pd +data = pd.DataFrame.from_dict([ + {"Pclass": 1, "Survived": 1}, + {"Pclass": 2, "Survived": 0}, + {"Pclass": 2, "Survived": 1} +]) +survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() +""" + result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"]) + assert result.values[1] == 0.5 + def test_starred(self): code = """ from math import radians, sin, cos, sqrt, atan2