Skip to content

Commit

Permalink
Remove risky imports, leave option for additional imports
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed May 27, 2024
1 parent db8afe9 commit 0ef1aa2
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/transformers/agents/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional


class InterpretorError(ValueError):
Expand All @@ -32,7 +32,6 @@ class InterpretorError(ValueError):
LIST_SAFE_MODULES = [
"random",
"collections",
"requests",
"math",
"time",
"queue",
Expand Down Expand Up @@ -353,7 +352,12 @@ def evaluate_listcomp(listcomp, state, tools):
return result


def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
def evaluate_ast(
expression: ast.AST,
state: Dict[str, Any],
tools: Dict[str, Callable],
authorized_imports: List[str] = LIST_SAFE_MODULES,
):
"""
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
set of functions.
Expand All @@ -369,6 +373,9 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
authorized_imports (`List[str]`):
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
Add more at your own risk!
"""
if isinstance(expression, ast.Assign):
# Assignement -> we evaluate the assignement which should update the state
Expand Down Expand Up @@ -475,7 +482,7 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
return result
elif isinstance(expression, ast.Import):
for alias in expression.names:
if alias.name in LIST_SAFE_MODULES:
if alias.name in authorized_imports:
module = __import__(alias.name)
state[alias.asname or alias.name] = module
else:
Expand All @@ -484,7 +491,7 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
elif isinstance(expression, ast.While):
return evaluate_while(expression, state, tools)
elif isinstance(expression, ast.ImportFrom):
if expression.module in LIST_SAFE_MODULES:
if expression.module in authorized_imports:
module = __import__(expression.module)
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
Expand All @@ -496,7 +503,9 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")


def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, state=None):
def evaluate_python_code(
code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES
):
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
of functions.
Expand Down Expand Up @@ -524,7 +533,7 @@ def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, s
state["print_outputs"] = ""
for idx, node in enumerate(expression.body):
try:
line_result = evaluate_ast(node, state, tools)
line_result = evaluate_ast(node, state, tools, authorized_imports)
except InterpretorError as e:
msg = f"You tried to execute the following code:\n{code}\n"
msg += f"You got these outputs:\n{state['print_outputs']}\n"
Expand Down

0 comments on commit 0ef1aa2

Please sign in to comment.