diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 64cc4a806ce4d4..328f9efb4d106e 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -229,7 +229,7 @@ def add_tool(self, tool: Tool): The tool to add to the toolbox. """ if tool.name in self._tools: - raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.") + raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.") self._tools[tool.name] = tool def remove_tool(self, tool_name: str): @@ -621,6 +621,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): output = self.python_evaluator( code_action, available_tools, + custom_tools={}, state=self.state, authorized_imports=self.authorized_imports, ) @@ -976,11 +977,9 @@ def __init__( self.python_evaluator = evaluate_python_code self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) + print(self.authorized_imports) self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) - self.available_tools = { - **BASE_PYTHON_TOOLS.copy(), - **self.toolbox.tools, - } # This list can be augmented by the code agent creating some new functions + self.custom_tools = {} def step(self): """ @@ -1032,7 +1031,11 @@ def step(self): try: result = self.python_evaluator( code_action, - tools=self.available_tools, + tools={ + **BASE_PYTHON_TOOLS.copy(), + **self.toolbox.tools, + }, + custom_tools=self.custom_tools, state=self.state, authorized_imports=self.authorized_imports, ) diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 0d67ae72fa32b7..636e37bad7ecb0 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -84,8 +84,8 @@ def get_iterable(obj): raise InterpreterError("Object is not iterable") -def evaluate_unaryop(expression, state, tools): - operand = evaluate_ast(expression.operand, state, tools) +def evaluate_unaryop(expression, state, tools, custom_tools): + operand = evaluate_ast(expression.operand, state, tools, custom_tools) if isinstance(expression.op, ast.USub): return -operand elif isinstance(expression.op, ast.UAdd): @@ -98,25 +98,25 @@ def evaluate_unaryop(expression, state, tools): raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") -def evaluate_lambda(lambda_expression, state, tools): +def evaluate_lambda(lambda_expression, state, tools, custom_tools): args = [arg.arg for arg in lambda_expression.args.args] def lambda_func(*values): new_state = state.copy() for arg, value in zip(args, values): new_state[arg] = value - return evaluate_ast(lambda_expression.body, new_state, tools) + return evaluate_ast(lambda_expression.body, new_state, tools, custom_tools) return lambda_func -def evaluate_while(while_loop, state, tools): +def evaluate_while(while_loop, state, tools, custom_tools): max_iterations = 1000 iterations = 0 - while evaluate_ast(while_loop.test, state, tools): + while evaluate_ast(while_loop.test, state, tools, custom_tools): for node in while_loop.body: try: - evaluate_ast(node, state, tools) + evaluate_ast(node, state, tools, custom_tools) except BreakException: return None except ContinueException: @@ -127,11 +127,11 @@ def evaluate_while(while_loop, state, tools): return None -def create_function(func_def, state, tools): +def create_function(func_def, state, tools, custom_tools): def new_func(*args, **kwargs): func_state = state.copy() arg_names = [arg.arg for arg in func_def.args.args] - default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults] + default_values = [evaluate_ast(d, state, tools, custom_tools) for d in func_def.args.defaults] # Apply default values defaults = dict(zip(arg_names[-len(default_values) :], default_values)) @@ -167,7 +167,7 @@ def new_func(*args, **kwargs): result = None try: for stmt in func_def.body: - result = evaluate_ast(stmt, func_state, tools) + result = evaluate_ast(stmt, func_state, tools, custom_tools) except ReturnException as e: result = e.value return result @@ -182,25 +182,25 @@ def create_class(class_name, class_bases, class_body): return type(class_name, tuple(class_bases), class_dict) -def evaluate_function_def(func_def, state, tools): - tools[func_def.name] = create_function(func_def, state, tools) - return tools[func_def.name] +def evaluate_function_def(func_def, state, tools, custom_tools): + custom_tools[func_def.name] = create_function(func_def, state, tools, custom_tools) + return custom_tools[func_def.name] -def evaluate_class_def(class_def, state, tools): +def evaluate_class_def(class_def, state, tools, custom_tools): class_name = class_def.name - bases = [evaluate_ast(base, state, tools) for base in class_def.bases] + bases = [evaluate_ast(base, state, tools, custom_tools) for base in class_def.bases] class_dict = {} for stmt in class_def.body: if isinstance(stmt, ast.FunctionDef): - class_dict[stmt.name] = evaluate_function_def(stmt, state, tools) + class_dict[stmt.name] = evaluate_function_def(stmt, state, tools, custom_tools) elif isinstance(stmt, ast.Assign): for target in stmt.targets: if isinstance(target, ast.Name): - class_dict[target.id] = evaluate_ast(stmt.value, state, tools) + class_dict[target.id] = evaluate_ast(stmt.value, state, tools, custom_tools) elif isinstance(target, ast.Attribute): - class_dict[target.attr] = evaluate_ast(stmt.value, state, tools) + class_dict[target.attr] = evaluate_ast(stmt.value, state, tools, custom_tools) else: raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") @@ -209,17 +209,17 @@ def evaluate_class_def(class_def, state, tools): return new_class -def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): +def evaluate_augassign(expression, state, tools, custom_tools): # Helper function to get current value and set new value based on the target type def get_current_value(target): if isinstance(target, ast.Name): return state.get(target.id, 0) elif isinstance(target, ast.Subscript): - obj = evaluate_ast(target.value, state, tools) - key = evaluate_ast(target.slice, state, tools) + obj = evaluate_ast(target.value, state, tools, custom_tools) + key = evaluate_ast(target.slice, state, tools, custom_tools) return obj[key] elif isinstance(target, ast.Attribute): - obj = evaluate_ast(target.value, state, tools) + obj = evaluate_ast(target.value, state, tools, custom_tools) return getattr(obj, target.attr) elif isinstance(target, ast.Tuple): return tuple(get_current_value(elt) for elt in target.elts) @@ -229,7 +229,7 @@ def get_current_value(target): raise InterpreterError("AugAssign not supported for {type(target)} targets.") current_value = get_current_value(expression.target) - value_to_add = evaluate_ast(expression.value, state, tools) + value_to_add = evaluate_ast(expression.value, state, tools, custom_tools) # Determine the operation and apply it if isinstance(expression.op, ast.Add): @@ -265,28 +265,28 @@ def get_current_value(target): raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") # Update the state - set_value(expression.target, updated_value, state, tools) + set_value(expression.target, updated_value, state, tools, custom_tools) return updated_value -def evaluate_boolop(node, state, tools): +def evaluate_boolop(node, state, tools, custom_tools): if isinstance(node.op, ast.And): for value in node.values: - if not evaluate_ast(value, state, tools): + if not evaluate_ast(value, state, tools, custom_tools): return False return True elif isinstance(node.op, ast.Or): for value in node.values: - if evaluate_ast(value, state, tools): + if evaluate_ast(value, state, tools, custom_tools): return True return False -def evaluate_binop(binop, state, tools): +def evaluate_binop(binop, state, tools, custom_tools): # Recursively evaluate the left and right operands - left_val = evaluate_ast(binop.left, state, tools) - right_val = evaluate_ast(binop.right, state, tools) + left_val = evaluate_ast(binop.left, state, tools, custom_tools) + right_val = evaluate_ast(binop.right, state, tools, custom_tools) # Determine the operation based on the type of the operator in the BinOp if isinstance(binop.op, ast.Add): @@ -317,11 +317,11 @@ def evaluate_binop(binop, state, tools): raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.") -def evaluate_assign(assign, state, tools): - result = evaluate_ast(assign.value, state, tools) +def evaluate_assign(assign, state, tools, custom_tools): + result = evaluate_ast(assign.value, state, tools, custom_tools) if len(assign.targets) == 1: target = assign.targets[0] - set_value(target, result, state, tools) + set_value(target, result, state, tools, custom_tools) else: if len(assign.targets) != len(result): raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.") @@ -332,11 +332,11 @@ def evaluate_assign(assign, state, tools): else: expanded_values.append(result) for tgt, val in zip(assign.targets, expanded_values): - set_value(tgt, val, state, tools) + set_value(tgt, val, state, tools, custom_tools) return result -def set_value(target, value, state, tools): +def set_value(target, value, state, tools, custom_tools): if isinstance(target, ast.Name): if target.id in tools: raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!") @@ -350,21 +350,21 @@ def set_value(target, value, state, tools): if len(target.elts) != len(value): raise InterpreterError("Cannot unpack tuple of wrong size") for i, elem in enumerate(target.elts): - set_value(elem, value[i], state, tools) + set_value(elem, value[i], state, tools, custom_tools) elif isinstance(target, ast.Subscript): - obj = evaluate_ast(target.value, state, tools) - key = evaluate_ast(target.slice, state, tools) + obj = evaluate_ast(target.value, state, tools, custom_tools) + key = evaluate_ast(target.slice, state, tools, custom_tools) obj[key] = value elif isinstance(target, ast.Attribute): - obj = evaluate_ast(target.value, state, tools) + obj = evaluate_ast(target.value, state, tools, custom_tools) setattr(obj, target.attr, value) -def evaluate_call(call, state, tools): +def evaluate_call(call, state, tools, custom_tools): if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): raise InterpreterError(f"This is not a correct function: {call.func}).") if isinstance(call.func, ast.Attribute): - obj = evaluate_ast(call.func.value, state, tools) + obj = evaluate_ast(call.func.value, state, tools, custom_tools) func_name = call.func.attr if not hasattr(obj, func_name): raise InterpreterError(f"Object {obj} has no attribute {func_name}") @@ -376,6 +376,8 @@ def evaluate_call(call, state, tools): func = state[func_name] elif func_name in tools: func = tools[func_name] + elif func_name in custom_tools: + func = custom_tools[func_name] elif func_name in ERRORS: func = ERRORS[func_name] else: @@ -386,21 +388,21 @@ def evaluate_call(call, state, tools): args = [] for arg in call.args: if isinstance(arg, ast.Starred): - args.extend(evaluate_ast(arg.value, state, tools)) + args.extend(evaluate_ast(arg.value, state, tools, custom_tools)) else: - args.append(evaluate_ast(arg, state, tools)) + args.append(evaluate_ast(arg, state, tools, custom_tools)) args = [] for arg in call.args: if isinstance(arg, ast.Starred): - unpacked = evaluate_ast(arg.value, state, tools) + unpacked = evaluate_ast(arg.value, state, tools, custom_tools) if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)): raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}") args.extend(unpacked) else: - args.append(evaluate_ast(arg, state, tools)) + args.append(evaluate_ast(arg, state, tools, custom_tools)) - kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} + kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools, custom_tools) for keyword in call.keywords} if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes # Instantiate the class using its constructor @@ -437,9 +439,9 @@ def evaluate_call(call, state, tools): return output -def evaluate_subscript(subscript, state, tools): - index = evaluate_ast(subscript.slice, state, tools) - value = evaluate_ast(subscript.value, state, tools) +def evaluate_subscript(subscript, state, tools, custom_tools): + index = evaluate_ast(subscript.slice, state, tools, custom_tools) + value = evaluate_ast(subscript.value, state, tools, custom_tools) if isinstance(value, pd.core.indexing._LocIndexer): parent_object = value.obj @@ -465,7 +467,7 @@ def evaluate_subscript(subscript, state, tools): raise InterpreterError(f"Could not index {value} with '{index}'.") -def evaluate_name(name, state, tools): +def evaluate_name(name, state, tools, custom_tools): if name.id in state: return state[name.id] elif name.id in tools: @@ -478,9 +480,9 @@ def evaluate_name(name, state, tools): raise InterpreterError(f"The variable `{name.id}` is not defined.") -def evaluate_condition(condition, state, tools): - left = evaluate_ast(condition.left, state, tools) - comparators = [evaluate_ast(c, state, tools) for c in condition.comparators] +def evaluate_condition(condition, state, tools, custom_tools): + left = evaluate_ast(condition.left, state, tools, custom_tools) + comparators = [evaluate_ast(c, state, tools, custom_tools) for c in condition.comparators] ops = [type(op) for op in condition.ops] result = True @@ -519,30 +521,30 @@ def evaluate_condition(condition, state, tools): return result if isinstance(result, (bool, pd.Series)) else result.all() -def evaluate_if(if_statement, state, tools): +def evaluate_if(if_statement, state, tools, custom_tools): result = None - test_result = evaluate_ast(if_statement.test, state, tools) + test_result = evaluate_ast(if_statement.test, state, tools, custom_tools) if test_result: for line in if_statement.body: - line_result = evaluate_ast(line, state, tools) + line_result = evaluate_ast(line, state, tools, custom_tools) if line_result is not None: result = line_result else: for line in if_statement.orelse: - line_result = evaluate_ast(line, state, tools) + line_result = evaluate_ast(line, state, tools, custom_tools) if line_result is not None: result = line_result return result -def evaluate_for(for_loop, state, tools): +def evaluate_for(for_loop, state, tools, custom_tools): result = None - iterator = evaluate_ast(for_loop.iter, state, tools) + iterator = evaluate_ast(for_loop.iter, state, tools, custom_tools) for counter in iterator: - set_value(for_loop.target, counter, state, tools) + set_value(for_loop.target, counter, state, tools, custom_tools) for node in for_loop.body: try: - line_result = evaluate_ast(node, state, tools) + line_result = evaluate_ast(node, state, tools, custom_tools) if line_result is not None: result = line_result except BreakException: @@ -555,12 +557,12 @@ def evaluate_for(for_loop, state, tools): return result -def evaluate_listcomp(listcomp, state, tools): +def evaluate_listcomp(listcomp, state, tools, custom_tools): def inner_evaluate(generators, index, current_state): if index >= len(generators): - return [evaluate_ast(listcomp.elt, current_state, tools)] + return [evaluate_ast(listcomp.elt, current_state, tools, custom_tools)] generator = generators[index] - iter_value = evaluate_ast(generator.iter, current_state, tools) + iter_value = evaluate_ast(generator.iter, current_state, tools, custom_tools) result = [] for value in iter_value: new_state = current_state.copy() @@ -569,46 +571,46 @@ def inner_evaluate(generators, index, current_state): new_state[elem.id] = value[idx] else: new_state[generator.target.id] = value - if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs): + if all(evaluate_ast(if_clause, new_state, tools, custom_tools) for if_clause in generator.ifs): result.extend(inner_evaluate(generators, index + 1, new_state)) return result return inner_evaluate(listcomp.generators, 0, state) -def evaluate_try(try_node, state, tools): +def evaluate_try(try_node, state, tools, custom_tools): try: for stmt in try_node.body: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, tools, custom_tools) except Exception as e: matched = False for handler in try_node.handlers: - if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)): + if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools, custom_tools)): matched = True if handler.name: state[handler.name] = e for stmt in handler.body: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, tools, custom_tools) break if not matched: raise e else: if try_node.orelse: for stmt in try_node.orelse: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, tools, custom_tools) finally: if try_node.finalbody: for stmt in try_node.finalbody: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, tools, custom_tools) -def evaluate_raise(raise_node, state, tools): +def evaluate_raise(raise_node, state, tools, custom_tools): if raise_node.exc is not None: - exc = evaluate_ast(raise_node.exc, state, tools) + exc = evaluate_ast(raise_node.exc, state, tools, custom_tools) else: exc = None if raise_node.cause is not None: - cause = evaluate_ast(raise_node.cause, state, tools) + cause = evaluate_ast(raise_node.cause, state, tools, custom_tools) else: cause = None if exc is not None: @@ -620,11 +622,11 @@ def evaluate_raise(raise_node, state, tools): raise InterpreterError("Re-raise is not supported without an active exception") -def evaluate_assert(assert_node, state, tools): - test_result = evaluate_ast(assert_node.test, state, tools) +def evaluate_assert(assert_node, state, tools, custom_tools): + test_result = evaluate_ast(assert_node.test, state, tools, custom_tools) if not test_result: if assert_node.msg: - msg = evaluate_ast(assert_node.msg, state, tools) + msg = evaluate_ast(assert_node.msg, state, tools, custom_tools) raise AssertionError(msg) else: # Include the failing condition in the assertion message @@ -632,10 +634,10 @@ def evaluate_assert(assert_node, state, tools): raise AssertionError(f"Assertion failed: {test_code}") -def evaluate_with(with_node, state, tools): +def evaluate_with(with_node, state, tools, custom_tools): contexts = [] for item in with_node.items: - context_expr = evaluate_ast(item.context_expr, state, tools) + context_expr = evaluate_ast(item.context_expr, state, tools, custom_tools) if item.optional_vars: state[item.optional_vars.id] = context_expr.__enter__() contexts.append(state[item.optional_vars.id]) @@ -645,7 +647,7 @@ def evaluate_with(with_node, state, tools): try: for stmt in with_node.body: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, tools, custom_tools) except Exception as e: for context in reversed(contexts): context.__exit__(type(e), e, e.__traceback__) @@ -681,16 +683,16 @@ def check_module_authorized(module_name): return None -def evaluate_dictcomp(dictcomp, state, tools): +def evaluate_dictcomp(dictcomp, state, tools, custom_tools): result = {} for gen in dictcomp.generators: - iter_value = evaluate_ast(gen.iter, state, tools) + iter_value = evaluate_ast(gen.iter, state, tools, custom_tools) for value in iter_value: new_state = state.copy() - set_value(gen.target, value, new_state, tools) - if all(evaluate_ast(if_clause, new_state, tools) for if_clause in gen.ifs): - key = evaluate_ast(dictcomp.key, new_state, tools) - val = evaluate_ast(dictcomp.value, new_state, tools) + set_value(gen.target, value, new_state, tools, custom_tools) + if all(evaluate_ast(if_clause, new_state, tools, custom_tools) for if_clause in gen.ifs): + key = evaluate_ast(dictcomp.key, new_state, tools, custom_tools) + val = evaluate_ast(dictcomp.value, new_state, tools, custom_tools) result[key] = val return result @@ -699,6 +701,7 @@ def evaluate_ast( expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable], + custom_tools: Dict[str, Callable], authorized_imports: List[str] = LIST_SAFE_MODULES, ): """ @@ -714,8 +717,9 @@ def evaluate_ast( A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation encounters assignements. tools (`Dict[str, Callable]`): - The functions that may be called during the evaluation. Any call to another function will fail with an - `InterpreterError`. + Functions that may be called during the evaluation. Trying to change one of these tools will raise an error. + custom_tools (`Dict[str, Callable]`): + Functions that may be called during the evaluation. These tools can be overwritten. authorized_imports (`List[str]`): The list of modules that can be imported by the code. By default, only a few safe modules are allowed. Add more at your own risk! @@ -729,105 +733,105 @@ def evaluate_ast( if isinstance(expression, ast.Assign): # Assignement -> we evaluate the assignment which should update the state # We return the variable assigned as it may be used to determine the final result. - return evaluate_assign(expression, state, tools) + return evaluate_assign(expression, state, tools, custom_tools) elif isinstance(expression, ast.AugAssign): - return evaluate_augassign(expression, state, tools) + return evaluate_augassign(expression, state, tools, custom_tools) elif isinstance(expression, ast.Call): # Function call -> we return the value of the function call - return evaluate_call(expression, state, tools) + return evaluate_call(expression, state, tools, custom_tools) elif isinstance(expression, ast.Constant): # Constant -> just return the value return expression.value elif isinstance(expression, ast.Tuple): - return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts) + return tuple(evaluate_ast(elt, state, tools, custom_tools) for elt in expression.elts) elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): - return evaluate_listcomp(expression, state, tools) + return evaluate_listcomp(expression, state, tools, custom_tools) elif isinstance(expression, ast.UnaryOp): - return evaluate_unaryop(expression, state, tools) + return evaluate_unaryop(expression, state, tools, custom_tools) elif isinstance(expression, ast.Starred): - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, tools, custom_tools) elif isinstance(expression, ast.BoolOp): # Boolean operation -> evaluate the operation - return evaluate_boolop(expression, state, tools) + return evaluate_boolop(expression, state, tools, custom_tools) elif isinstance(expression, ast.Break): raise BreakException() elif isinstance(expression, ast.Continue): raise ContinueException() elif isinstance(expression, ast.BinOp): # Binary operation -> execute operation - return evaluate_binop(expression, state, tools) + return evaluate_binop(expression, state, tools, custom_tools) elif isinstance(expression, ast.Compare): # Comparison -> evaluate the comparison - return evaluate_condition(expression, state, tools) + return evaluate_condition(expression, state, tools, custom_tools) elif isinstance(expression, ast.Lambda): - return evaluate_lambda(expression, state, tools) + return evaluate_lambda(expression, state, tools, custom_tools) elif isinstance(expression, ast.FunctionDef): - return evaluate_function_def(expression, state, tools) + return evaluate_function_def(expression, state, tools, custom_tools) elif isinstance(expression, ast.Dict): # Dict -> evaluate all keys and values - keys = [evaluate_ast(k, state, tools) for k in expression.keys] - values = [evaluate_ast(v, state, tools) for v in expression.values] + keys = [evaluate_ast(k, state, tools, custom_tools) for k in expression.keys] + values = [evaluate_ast(v, state, tools, custom_tools) for v in expression.values] return dict(zip(keys, values)) elif isinstance(expression, ast.Expr): # Expression -> evaluate the content - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, tools, custom_tools) elif isinstance(expression, ast.For): # For loop -> execute the loop - return evaluate_for(expression, state, tools) + return evaluate_for(expression, state, tools, custom_tools) elif isinstance(expression, ast.FormattedValue): # Formatted value (part of f-string) -> evaluate the content and return - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, tools, custom_tools) elif isinstance(expression, ast.If): # If -> execute the right branch - return evaluate_if(expression, state, tools) + return evaluate_if(expression, state, tools, custom_tools) elif hasattr(ast, "Index") and isinstance(expression, ast.Index): - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, tools, custom_tools) elif isinstance(expression, ast.JoinedStr): - return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values]) + return "".join([str(evaluate_ast(v, state, tools, custom_tools)) for v in expression.values]) elif isinstance(expression, ast.List): # List -> evaluate all elements - return [evaluate_ast(elt, state, tools) for elt in expression.elts] + return [evaluate_ast(elt, state, tools, custom_tools) for elt in expression.elts] elif isinstance(expression, ast.Name): # Name -> pick up the value in the state - return evaluate_name(expression, state, tools) + return evaluate_name(expression, state, tools, custom_tools) elif isinstance(expression, ast.Subscript): # Subscript -> return the value of the indexing - return evaluate_subscript(expression, state, tools) + return evaluate_subscript(expression, state, tools, custom_tools) elif isinstance(expression, ast.IfExp): - test_val = evaluate_ast(expression.test, state, tools) + test_val = evaluate_ast(expression.test, state, tools, custom_tools) if test_val: - return evaluate_ast(expression.body, state, tools) + return evaluate_ast(expression.body, state, tools, custom_tools) else: - return evaluate_ast(expression.orelse, state, tools) + return evaluate_ast(expression.orelse, state, tools, custom_tools) elif isinstance(expression, ast.Attribute): - value = evaluate_ast(expression.value, state, tools) + value = evaluate_ast(expression.value, state, tools, custom_tools) return getattr(value, expression.attr) elif isinstance(expression, ast.Slice): return slice( - evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None, - evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None, - evaluate_ast(expression.step, state, tools) if expression.step is not None else None, + evaluate_ast(expression.lower, state, tools, custom_tools) if expression.lower is not None else None, + evaluate_ast(expression.upper, state, tools, custom_tools) if expression.upper is not None else None, + evaluate_ast(expression.step, state, tools, custom_tools) if expression.step is not None else None, ) elif isinstance(expression, ast.DictComp): - return evaluate_dictcomp(expression, state, tools) + return evaluate_dictcomp(expression, state, tools, custom_tools) elif isinstance(expression, ast.While): - return evaluate_while(expression, state, tools) + return evaluate_while(expression, state, tools, custom_tools) elif isinstance(expression, (ast.Import, ast.ImportFrom)): return import_modules(expression, state, authorized_imports) elif isinstance(expression, ast.ClassDef): - return evaluate_class_def(expression, state, tools) + return evaluate_class_def(expression, state, tools, custom_tools) elif isinstance(expression, ast.Try): - return evaluate_try(expression, state, tools) + return evaluate_try(expression, state, tools, custom_tools) elif isinstance(expression, ast.Raise): - return evaluate_raise(expression, state, tools) + return evaluate_raise(expression, state, tools, custom_tools) elif isinstance(expression, ast.Assert): - return evaluate_assert(expression, state, tools) + return evaluate_assert(expression, state, tools, custom_tools) elif isinstance(expression, ast.With): - return evaluate_with(expression, state, tools) + return evaluate_with(expression, state, tools, custom_tools) elif isinstance(expression, ast.Set): - return {evaluate_ast(elt, state, tools) for elt in expression.elts} + return {evaluate_ast(elt, state, tools, custom_tools) for elt in expression.elts} elif isinstance(expression, ast.Return): - raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None) + raise ReturnException(evaluate_ast(expression.value, state, tools, custom_tools) if expression.value else None) else: # For now we refuse anything else. Let's add things as we need them. raise InterpreterError(f"{expression.__class__.__name__} is not supported.") @@ -836,6 +840,7 @@ def evaluate_ast( def evaluate_python_code( code: str, tools: Optional[Dict[str, Callable]] = None, + custom_tools: Optional[Dict[str, Callable]] = None, state: Optional[Dict[str, Any]] = None, authorized_imports: List[str] = LIST_SAFE_MODULES, ): @@ -864,6 +869,8 @@ def evaluate_python_code( state = {} if tools is None: tools = {} + if custom_tools is None: + custom_tools = {} result = None global PRINT_OUTPUTS PRINT_OUTPUTS = "" @@ -871,7 +878,7 @@ def evaluate_python_code( OPERATIONS_COUNT = 0 for node in expression.body: try: - result = evaluate_ast(node, state, tools, authorized_imports) + result = evaluate_ast(node, state, tools, custom_tools, authorized_imports) except InterpreterError as e: msg = "" if len(PRINT_OUTPUTS) > 0: diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 5bdaea1651b741..f47d0b0c35c3e0 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -223,7 +223,7 @@ def test_init_agent_with_different_toolsets(self): # check that add_base_tools will not interfere with existing tools with pytest.raises(KeyError) as e: agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True) - assert "python_interpreter already exists in the toolbox" in str(e) + assert "already exists in the toolbox" in str(e) # check that python_interpreter base tool does not get added to code agents agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 1afb2094f208be..680c9a696ef6e5 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -418,11 +418,11 @@ def test_additional_imports(self): code = "import numpy as np" evaluate_python_code(code, authorized_imports=["numpy"], state={}) - code = "import matplotlib.pyplot as plt" - evaluate_python_code(code, authorized_imports=["matplotlib.pyplot"], state={}) - evaluate_python_code(code, authorized_imports=["matplotlib"], state={}) + code = "import numpy.random as rd" + evaluate_python_code(code, authorized_imports=["numpy.random"], state={}) + evaluate_python_code(code, authorized_imports=["numpy"], state={}) with pytest.raises(InterpreterError): - evaluate_python_code(code, authorized_imports=["pyplot"], state={}) + evaluate_python_code(code, authorized_imports=["random"], state={}) def test_multiple_comparators(self): code = "0 <= -1 < 4 and 0 <= -5 < 4" @@ -455,7 +455,7 @@ def function(): print("2") function()""" state = {} - evaluate_python_code(code, {"print": print}, state) + evaluate_python_code(code, {"print": print}, state=state) assert state["print_outputs"] == "1\n2\n" def test_tuple_target_in_iterator(self):