diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 1ddfb6b4174777..0cf56335d3e66e 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -25,7 +25,19 @@ from .agent_types import AgentAudio, AgentImage, AgentText from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfEngine, MessageRole -from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT +from .prompts import ( + DEFAULT_CODE_SYSTEM_PROMPT, + DEFAULT_REACT_CODE_SYSTEM_PROMPT, + DEFAULT_REACT_JSON_SYSTEM_PROMPT, + PLAN_UPDATE_FINAL_PLAN_REDACTION, + SYSTEM_PROMPT_FACTS, + SYSTEM_PROMPT_FACTS_UPDATE, + SYSTEM_PROMPT_PLAN, + SYSTEM_PROMPT_PLAN_UPDATE, + USER_PROMPT_FACTS_UPDATE, + USER_PROMPT_PLAN, + USER_PROMPT_PLAN_UPDATE, +) from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .tools import ( DEFAULT_TOOL_DESCRIPTION_TEMPLATE, @@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: def parse_code_blob(code_blob: str) -> str: try: - pattern = r"```(?:py|python)?\n(.*?)```" + pattern = r"```(?:py|python)?\n(.*?)\n```" match = re.search(pattern, code_blob, re.DOTALL) return match.group(1).strip() except Exception as e: raise ValueError( - f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}" + f""" +The code blob you used is invalid: due to the following error: {e} +This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance: +Thoughts: Your thoughts +Code: +```py +# Your python code here +```""" ) @@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]: tool_call = parse_json_blob(json_blob) if "action" in tool_call and "action_input" in tool_call: return tool_call["action"], tool_call["action_input"] + elif "action" in tool_call: + return tool_call["action"], None else: raise ValueError( f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}" @@ -208,7 +229,7 @@ def add_tool(self, tool: Tool): The tool to add to the toolbox. """ if tool.name in self._tools: - raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.") + raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.") self._tools[tool.name] = tool def remove_tool(self, tool_name: str): @@ -359,12 +380,8 @@ def toolbox(self) -> Toolbox: """Get the toolbox currently available to the agent""" return self._toolbox - def initialize_for_run(self, task: str, **kwargs): + def initialize_for_run(self): self.token_count = 0 - self.task = task - if len(kwargs) > 0: - self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." - self.state = kwargs.copy() self.system_prompt = format_prompt_with_tools( self._toolbox, self.system_prompt_template, @@ -380,7 +397,7 @@ def initialize_for_run(self, task: str, **kwargs): self.logger.debug("System prompt is as follows:") self.logger.debug(self.system_prompt) - def write_inner_memory_from_logs(self) -> List[Dict[str, str]]: + def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: """ Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages that can be used as input to the LLM. @@ -390,43 +407,51 @@ def write_inner_memory_from_logs(self) -> List[Dict[str, str]]: "role": MessageRole.USER, "content": "Task: " + self.logs[0]["task"], } - memory = [prompt_message, task_message] + if summary_mode: + memory = [task_message] + else: + memory = [prompt_message, task_message] for i, step_log in enumerate(self.logs[1:]): - if "llm_output" in step_log: - thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"} + if "llm_output" in step_log and not summary_mode: + thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()} + memory.append(thought_message) + if "facts" in step_log: + thought_message = { + "role": MessageRole.ASSISTANT, + "content": "[FACTS LIST]:\n" + step_log["facts"].strip(), + } memory.append(thought_message) - if "error" in step_log: - message_content = ( - "Error: " - + str(step_log["error"]) - + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" - ) - elif "observation" in step_log: - message_content = f"Observation: {step_log['observation']}" - tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content} - memory.append(tool_response_message) - - if len(memory) % 3 == 0: - reminder_content = ( - "Reminder: you are working towards solving the following task: " + self.logs[0]["task"] - ) - reminder_content += "\nHere is a summary of your past tool calls and their results:" - for j in range(i + 1): - reminder_content += "\nStep " + str(j + 1) - if "tool_call" in self.logs[j]: - reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"]) - if self.memory_verbose: - if "observation" in self.logs[j]: - reminder_content += "\nObservation:" + str(self.logs[j]["observation"]) - if "error" in self.logs[j]: - reminder_content += "\nError:" + str(self.logs[j]["error"]) - memory.append( - { - "role": MessageRole.USER, - "content": reminder_content, - } - ) + if "plan" in step_log and not summary_mode: + thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()} + memory.append(thought_message) + + if "tool_call" in step_log and summary_mode: + tool_call_message = { + "role": MessageRole.ASSISTANT, + "content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(), + } + memory.append(tool_call_message) + + if "task" in step_log: + tool_call_message = { + "role": MessageRole.USER, + "content": "New task:\n" + step_log["task"], + } + memory.append(tool_call_message) + + if "error" in step_log or "observation" in step_log: + if "error" in step_log: + message_content = ( + f"[OUTPUT OF STEP {i}] Error: " + + str(step_log["error"]) + + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" + ) + elif "observation" in step_log: + message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}" + tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content} + memory.append(tool_response_message) + return memory def get_succinct_logs(self): @@ -459,7 +484,7 @@ def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: This method replaces arguments with the actual values from the state if they refer to state variables. Args: - tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox). + tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). arguments (Dict[str, str]): Arguments passed to the Tool. """ if tool_name not in self.toolbox.tools: @@ -559,7 +584,11 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): agent.run("What is the result of 2 power 3.7384?") ``` """ - self.initialize_for_run(task, **kwargs) + self.task = task + if len(kwargs) > 0: + self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." + self.state = kwargs.copy() + self.initialize_for_run() # Run LLM prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt} @@ -598,7 +627,8 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools} output = self.python_evaluator( code_action, - available_tools, + static_tools=available_tools, + custom_tools={}, state=self.state, authorized_imports=self.authorized_imports, ) @@ -623,6 +653,7 @@ def __init__( llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + planning_interval: Optional[int] = None, **kwargs, ): super().__init__( @@ -632,6 +663,7 @@ def __init__( tool_description_template=tool_description_template, **kwargs, ) + self.planning_interval = planning_interval def provide_final_answer(self, task) -> str: """ @@ -655,11 +687,13 @@ def provide_final_answer(self, task) -> str: except Exception as e: return f"Error in generating final llm output: {e}." - def run(self, task: str, stream: bool = False, **kwargs): + def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs): """ Runs the agent for the given task. + Args: task (`str`): The task to perform + Example: ```py from transformers.agents import ReactCodeAgent @@ -667,14 +701,23 @@ def run(self, task: str, stream: bool = False, **kwargs): agent.run("What is the result of 2 power 3.7384?") ``` """ + self.task = task + if len(kwargs) > 0: + self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." + self.state = kwargs.copy() + if reset: + self.initialize_for_run() + else: + self.logs.append({"task": task}) if stream: - return self.stream_run(task, **kwargs) + return self.stream_run(task) else: - return self.direct_run(task, **kwargs) - - def stream_run(self, task: str, **kwargs): - self.initialize_for_run(task, **kwargs) + return self.direct_run(task) + def stream_run(self, task: str): + """ + Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method. + """ final_answer = None iteration = 0 while final_answer is None and iteration < self.max_iterations: @@ -700,13 +743,16 @@ def stream_run(self, task: str, **kwargs): yield final_answer - def direct_run(self, task: str, **kwargs): - self.initialize_for_run(task, **kwargs) - + def direct_run(self, task: str): + """ + Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method. + """ final_answer = None iteration = 0 while final_answer is None and iteration < self.max_iterations: try: + if self.planning_interval is not None and iteration % self.planning_interval == 0: + self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) step_logs = self.step() if "final_answer" in step_logs: final_answer = step_logs["final_answer"] @@ -726,6 +772,96 @@ def direct_run(self, task: str, **kwargs): return final_answer + def planning_step(self, task, is_first_step: bool = False, iteration: int = None): + """ + Used periodically by the agent to plan the next steps to reach the objective. + + Args: + task (`str`): The task to perform + is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan. + iteration (`int`): The number of the current step, used as an indication for the LLM. + """ + if is_first_step: + message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS} + message_prompt_task = { + "role": MessageRole.USER, + "content": f"""Here is the task: +``` +{task} +``` +Now begin!""", + } + + answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task]) + + message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN} + message_user_prompt_plan = { + "role": MessageRole.USER, + "content": USER_PROMPT_PLAN.format( + task=task, + tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), + answer_facts=answer_facts, + ), + } + answer_plan = self.llm_engine( + [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=[""] + ) + + final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task: +``` +{answer_plan} +```""" + final_facts_redaction = f"""Here are the facts that I know so far: +``` +{answer_facts} +```""".strip() + self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) + self.logger.debug("===== Initial plan: =====") + self.logger.debug(final_plan_redaction) + else: # update plan + agent_memory = self.write_inner_memory_from_logs( + summary_mode=False + ) # This will not log the plan but will log facts + + # Redact updated facts + facts_update_system_prompt = { + "role": MessageRole.SYSTEM, + "content": SYSTEM_PROMPT_FACTS_UPDATE, + } + facts_update_message = { + "role": MessageRole.USER, + "content": USER_PROMPT_FACTS_UPDATE, + } + facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message]) + + # Redact updated plan + plan_update_message = { + "role": MessageRole.SYSTEM, + "content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task), + } + plan_update_message_user = { + "role": MessageRole.USER, + "content": USER_PROMPT_PLAN_UPDATE.format( + task=task, + tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), + facts_update=facts_update, + remaining_steps=(self.max_iterations - iteration), + ), + } + plan_update = self.llm_engine( + [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=[""] + ) + + # Log final facts and plan + final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update) + final_facts_redaction = f"""Here is the updated list of the facts that I know: +``` +{facts_update} +```""" + self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) + self.logger.debug("===== Updated plan: =====") + self.logger.debug(final_plan_redaction) + class ReactJsonAgent(ReactAgent): """ @@ -740,6 +876,7 @@ def __init__( llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + planning_interval: Optional[int] = None, **kwargs, ): super().__init__( @@ -747,6 +884,7 @@ def __init__( llm_engine=llm_engine, system_prompt=system_prompt, tool_description_template=tool_description_template, + planning_interval=planning_interval, **kwargs, ) @@ -792,11 +930,16 @@ def step(self): self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}") if tool_name == "final_answer": if isinstance(arguments, dict): - answer = arguments["answer"] + if "answer" in arguments: + answer = arguments["answer"] + if ( + isinstance(answer, str) and answer in self.state.keys() + ): # if the answer is a state variable, return the value + answer = self.state[answer] + else: + answer = arguments else: answer = arguments - if answer in self.state: # if the answer is a state variable, return the value - answer = self.state[answer] current_step_logs["final_answer"] = answer return current_step_logs else: @@ -835,6 +978,7 @@ def __init__( system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, additional_authorized_imports: Optional[List[str]] = None, + planning_interval: Optional[int] = None, **kwargs, ): super().__init__( @@ -842,6 +986,7 @@ def __init__( llm_engine=llm_engine, system_prompt=system_prompt, tool_description_template=tool_description_template, + planning_interval=planning_interval, **kwargs, ) @@ -856,10 +1001,7 @@ def __init__( self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) - self.available_tools = { - **BASE_PYTHON_TOOLS.copy(), - **self.toolbox.tools, - } # This list can be augmented by the code agent creating some new functions + self.custom_tools = {} def step(self): """ @@ -911,7 +1053,11 @@ def step(self): try: result = self.python_evaluator( code_action, - tools=self.available_tools, + static_tools={ + **BASE_PYTHON_TOOLS.copy(), + **self.toolbox.tools, + }, + custom_tools=self.custom_tools, state=self.state, authorized_imports=self.authorized_imports, ) @@ -920,7 +1066,7 @@ def step(self): self.logger.log(32, information) current_step_logs["observation"] = information except Exception as e: - error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}" + error_msg = f"Code execution failed due to the following error:\n{str(e)}" if "'dict' object has no attribute 'read'" in str(e): error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string." raise AgentExecutionError(error_msg) diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 4a6f6dad3dca13..41909776726eca 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -173,7 +173,7 @@ def __init__(self, *args, authorized_imports=None, **kwargs): def forward(self, code): output = str( - evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports) + evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports) ) return output diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 515ce3d439a499..4c3f9b56bcd7fb 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -365,7 +365,118 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables. 8. You can use imports in your code, but only from the following list of modules: <> -9. Don't give up! You're in charge of solving the task, not providing directions to solve it. +9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. +10. Don't give up! You're in charge of solving the task, not providing directions to solve it. Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ + +SYSTEM_PROMPT_FACTS = """Below I will present you a task. + +You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need. +To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it. +Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey: + +--- +### 1. Facts given in the task +List here the specific facts given in the task that could help you (there might be nothing here). + +### 2. Facts to look up +List here any facts that we may need to look up. +Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here. + +### 3. Facts to derive +List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation. + +Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings: +### 1. Facts given in the task +### 2. Facts to look up +### 3. Facts to derive +Do not add anything else.""" + +SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools. + +Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts. +This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. +Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS. +After writing the final step of the plan, write the '\n' tag and stop there.""" + +USER_PROMPT_PLAN = """ +Here is your task: + +Task: +``` +{task} +``` + +Your plan can leverage any of these tools: +{tool_descriptions} + +List of facts that you know: +``` +{answer_facts} +``` + +Now begin! Write your plan below.""" + +SYSTEM_PROMPT_FACTS_UPDATE = """ +You are a world expert at gathering known and unknown facts based on a conversation. +Below you will find a task, and ahistory of attempts made to solve the task. You will have to produce a list of these: +### 1. Facts given in the task +### 2. Facts that we have learned +### 3. Facts still to look up +### 4. Facts still to derive +Find the task and history below.""" + +USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts. +But since in your previous steps you may have learned useful new facts or invalidated some false ones. +Please update your list of facts based on the previous history, and provide these headings: +### 1. Facts given in the task +### 2. Facts that we have learned +### 3. Facts still to look up +### 4. Facts still to derive + +Now write your new list of facts below.""" + +SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools. + +You have been given a task: +``` +{task} +``` + +Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task. +If the previous tries so far have met some success, you can make an updated plan based on these actions. +If you are stalled, you can make a completely new plan starting from scratch. +""" + +USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task: +``` +{task} +``` + +You have access to these tools: +{tool_descriptions} + +Here is the up to date list of facts that you know: +``` +{facts_update} +``` + +Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts. +This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. +Beware that you have {remaining_steps} steps remaining. +Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS. +After writing the final step of the plan, write the '\n' tag and stop there. + +Now write your new plan below.""" + +PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given: +``` +{task} +``` + +Here is my new/updated plan of action to solve the task: +``` +{plan_update} +```""" diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 1235bb95c3ae02..e641a8d0c17dae 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -18,8 +18,17 @@ import builtins import difflib from collections.abc import Mapping +from importlib import import_module from typing import Any, Callable, Dict, List, Optional +import numpy as np + +from ..utils import is_pandas_available + + +if is_pandas_available(): + import pandas as pd + class InterpreterError(ValueError): """ @@ -50,7 +59,8 @@ class InterpreterError(ValueError): "unicodedata", ] -PRINT_OUTPUTS = "" +PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000 +OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000 class BreakException(Exception): @@ -75,8 +85,8 @@ def get_iterable(obj): raise InterpreterError("Object is not iterable") -def evaluate_unaryop(expression, state, tools): - operand = evaluate_ast(expression.operand, state, tools) +def evaluate_unaryop(expression, state, static_tools, custom_tools): + operand = evaluate_ast(expression.operand, state, static_tools, custom_tools) if isinstance(expression.op, ast.USub): return -operand elif isinstance(expression.op, ast.UAdd): @@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools): raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") -def evaluate_lambda(lambda_expression, state, tools): +def evaluate_lambda(lambda_expression, state, static_tools, custom_tools): args = [arg.arg for arg in lambda_expression.args.args] def lambda_func(*values): new_state = state.copy() for arg, value in zip(args, values): new_state[arg] = value - return evaluate_ast(lambda_expression.body, new_state, tools) + return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools) return lambda_func -def evaluate_while(while_loop, state, tools): +def evaluate_while(while_loop, state, static_tools, custom_tools): max_iterations = 1000 iterations = 0 - while evaluate_ast(while_loop.test, state, tools): + while evaluate_ast(while_loop.test, state, static_tools, custom_tools): for node in while_loop.body: try: - evaluate_ast(node, state, tools) + evaluate_ast(node, state, static_tools, custom_tools) except BreakException: return None except ContinueException: @@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools): return None -def create_function(func_def, state, tools): +def create_function(func_def, state, static_tools, custom_tools): def new_func(*args, **kwargs): func_state = state.copy() arg_names = [arg.arg for arg in func_def.args.args] - default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults] + default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults] # Apply default values defaults = dict(zip(arg_names[-len(default_values) :], default_values)) @@ -158,7 +168,7 @@ def new_func(*args, **kwargs): result = None try: for stmt in func_def.body: - result = evaluate_ast(stmt, func_state, tools) + result = evaluate_ast(stmt, func_state, static_tools, custom_tools) except ReturnException as e: result = e.value return result @@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body): return type(class_name, tuple(class_bases), class_dict) -def evaluate_function_def(func_def, state, tools): - tools[func_def.name] = create_function(func_def, state, tools) - return tools[func_def.name] +def evaluate_function_def(func_def, state, static_tools, custom_tools): + custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools) + return custom_tools[func_def.name] -def evaluate_class_def(class_def, state, tools): +def evaluate_class_def(class_def, state, static_tools, custom_tools): class_name = class_def.name - bases = [evaluate_ast(base, state, tools) for base in class_def.bases] + bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases] class_dict = {} for stmt in class_def.body: if isinstance(stmt, ast.FunctionDef): - class_dict[stmt.name] = evaluate_function_def(stmt, state, tools) + class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools) elif isinstance(stmt, ast.Assign): for target in stmt.targets: if isinstance(target, ast.Name): - class_dict[target.id] = evaluate_ast(stmt.value, state, tools) + class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools) elif isinstance(target, ast.Attribute): - class_dict[target.attr] = evaluate_ast(stmt.value, state, tools) + class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools) else: raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") @@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools): return new_class -def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): +def evaluate_augassign(expression, state, static_tools, custom_tools): # Helper function to get current value and set new value based on the target type def get_current_value(target): if isinstance(target, ast.Name): return state.get(target.id, 0) elif isinstance(target, ast.Subscript): - obj = evaluate_ast(target.value, state, tools) - key = evaluate_ast(target.slice, state, tools) + obj = evaluate_ast(target.value, state, static_tools, custom_tools) + key = evaluate_ast(target.slice, state, static_tools, custom_tools) return obj[key] elif isinstance(target, ast.Attribute): - obj = evaluate_ast(target.value, state, tools) + obj = evaluate_ast(target.value, state, static_tools, custom_tools) return getattr(obj, target.attr) elif isinstance(target, ast.Tuple): return tuple(get_current_value(elt) for elt in target.elts) @@ -220,7 +230,7 @@ def get_current_value(target): raise InterpreterError("AugAssign not supported for {type(target)} targets.") current_value = get_current_value(expression.target) - value_to_add = evaluate_ast(expression.value, state, tools) + value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools) # Determine the operation and apply it if isinstance(expression.op, ast.Add): @@ -256,28 +266,28 @@ def get_current_value(target): raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") # Update the state - set_value(expression.target, updated_value, state, tools) + set_value(expression.target, updated_value, state, static_tools, custom_tools) return updated_value -def evaluate_boolop(node, state, tools): +def evaluate_boolop(node, state, static_tools, custom_tools): if isinstance(node.op, ast.And): for value in node.values: - if not evaluate_ast(value, state, tools): + if not evaluate_ast(value, state, static_tools, custom_tools): return False return True elif isinstance(node.op, ast.Or): for value in node.values: - if evaluate_ast(value, state, tools): + if evaluate_ast(value, state, static_tools, custom_tools): return True return False -def evaluate_binop(binop, state, tools): +def evaluate_binop(binop, state, static_tools, custom_tools): # Recursively evaluate the left and right operands - left_val = evaluate_ast(binop.left, state, tools) - right_val = evaluate_ast(binop.right, state, tools) + left_val = evaluate_ast(binop.left, state, static_tools, custom_tools) + right_val = evaluate_ast(binop.right, state, static_tools, custom_tools) # Determine the operation based on the type of the operator in the BinOp if isinstance(binop.op, ast.Add): @@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools): raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.") -def evaluate_assign(assign, state, tools): - result = evaluate_ast(assign.value, state, tools) +def evaluate_assign(assign, state, static_tools, custom_tools): + result = evaluate_ast(assign.value, state, static_tools, custom_tools) if len(assign.targets) == 1: target = assign.targets[0] - set_value(target, result, state, tools) + set_value(target, result, state, static_tools, custom_tools) else: if len(assign.targets) != len(result): raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.") - for tgt, val in zip(assign.targets, result): - set_value(tgt, val, state, tools) + expanded_values = [] + for tgt in assign.targets: + if isinstance(tgt, ast.Starred): + expanded_values.extend(result) + else: + expanded_values.append(result) + for tgt, val in zip(assign.targets, expanded_values): + set_value(tgt, val, state, static_tools, custom_tools) return result -def set_value(target, value, state, tools): +def set_value(target, value, state, static_tools, custom_tools): if isinstance(target, ast.Name): - if target.id in tools: + if target.id in static_tools: raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!") state[target.id] = value elif isinstance(target, ast.Tuple): if not isinstance(value, tuple): - raise InterpreterError("Cannot unpack non-tuple value") + if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): + value = tuple(value) + else: + raise InterpreterError("Cannot unpack non-tuple value") if len(target.elts) != len(value): raise InterpreterError("Cannot unpack tuple of wrong size") for i, elem in enumerate(target.elts): - set_value(elem, value[i], state, tools) + set_value(elem, value[i], state, static_tools, custom_tools) elif isinstance(target, ast.Subscript): - obj = evaluate_ast(target.value, state, tools) - key = evaluate_ast(target.slice, state, tools) + obj = evaluate_ast(target.value, state, static_tools, custom_tools) + key = evaluate_ast(target.slice, state, static_tools, custom_tools) obj[key] = value elif isinstance(target, ast.Attribute): - obj = evaluate_ast(target.value, state, tools) + obj = evaluate_ast(target.value, state, static_tools, custom_tools) setattr(obj, target.attr, value) -def evaluate_call(call, state, tools): +def evaluate_call(call, state, static_tools, custom_tools): if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): - raise InterpreterError( - f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})." - ) + raise InterpreterError(f"This is not a correct function: {call.func}).") if isinstance(call.func, ast.Attribute): - obj = evaluate_ast(call.func.value, state, tools) + obj = evaluate_ast(call.func.value, state, static_tools, custom_tools) func_name = call.func.attr if not hasattr(obj, func_name): raise InterpreterError(f"Object {obj} has no attribute {func_name}") func = getattr(obj, func_name) + elif isinstance(call.func, ast.Name): func_name = call.func.id if func_name in state: func = state[func_name] - elif func_name in tools: - func = tools[func_name] + elif func_name in static_tools: + func = static_tools[func_name] + elif func_name in custom_tools: + func = custom_tools[func_name] elif func_name in ERRORS: func = ERRORS[func_name] else: raise InterpreterError( - f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})." + f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})." ) - args = [evaluate_ast(arg, state, tools) for arg in call.args] - kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} + args = [] + for arg in call.args: + if isinstance(arg, ast.Starred): + args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools)) + else: + args.append(evaluate_ast(arg, state, static_tools, custom_tools)) + + args = [] + for arg in call.args: + if isinstance(arg, ast.Starred): + unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools) + if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)): + raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}") + args.extend(unpacked) + else: + args.append(evaluate_ast(arg, state, static_tools, custom_tools)) + + kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords} if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes # Instantiate the class using its constructor @@ -397,24 +433,31 @@ def evaluate_call(call, state, tools): output = " ".join(map(str, args)) global PRINT_OUTPUTS PRINT_OUTPUTS += output + "\n" + # cap the number of lines return output else: # Assume it's a callable object output = func(*args, **kwargs) return output -def evaluate_subscript(subscript, state, tools): - index = evaluate_ast(subscript.slice, state, tools) - value = evaluate_ast(subscript.value, state, tools) - if isinstance(index, slice): +def evaluate_subscript(subscript, state, static_tools, custom_tools): + index = evaluate_ast(subscript.slice, state, static_tools, custom_tools) + value = evaluate_ast(subscript.value, state, static_tools, custom_tools) + + if isinstance(value, pd.core.indexing._LocIndexer): + parent_object = value.obj + return parent_object.loc[index] + if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)): + return value[index] + elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy): + return value[index] + elif isinstance(index, slice): return value[index] elif isinstance(value, (list, tuple)): - # Ensure the index is within bounds if not (-len(value) <= index < len(value)): raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}") return value[int(index)] elif isinstance(value, str): - # Ensure the index is within bounds if not (-len(value) <= index < len(value)): raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}") return value[index] @@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools): raise InterpreterError(f"Could not index {value} with '{index}'.") -def evaluate_name(name, state, tools): +def evaluate_name(name, state, static_tools, custom_tools): if name.id in state: return state[name.id] - elif name.id in tools: - return tools[name.id] + elif name.id in static_tools: + return static_tools[name.id] elif name.id in ERRORS: return ERRORS[name.id] close_matches = difflib.get_close_matches(name.id, list(state.keys())) @@ -440,9 +483,9 @@ def evaluate_name(name, state, tools): raise InterpreterError(f"The variable `{name.id}` is not defined.") -def evaluate_condition(condition, state, tools): - left = evaluate_ast(condition.left, state, tools) - comparators = [evaluate_ast(c, state, tools) for c in condition.comparators] +def evaluate_condition(condition, state, static_tools, custom_tools): + left = evaluate_ast(condition.left, state, static_tools, custom_tools) + comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators] ops = [type(op) for op in condition.ops] result = True @@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools): for op, comparator in zip(ops, comparators): if op == ast.Eq: - result = result and (current_left == comparator) + current_result = current_left == comparator elif op == ast.NotEq: - result = result and (current_left != comparator) + current_result = current_left != comparator elif op == ast.Lt: - result = result and (current_left < comparator) + current_result = current_left < comparator elif op == ast.LtE: - result = result and (current_left <= comparator) + current_result = current_left <= comparator elif op == ast.Gt: - result = result and (current_left > comparator) + current_result = current_left > comparator elif op == ast.GtE: - result = result and (current_left >= comparator) + current_result = current_left >= comparator elif op == ast.Is: - result = result and (current_left is comparator) + current_result = current_left is comparator elif op == ast.IsNot: - result = result and (current_left is not comparator) + current_result = current_left is not comparator elif op == ast.In: - result = result and (current_left in comparator) + current_result = current_left in comparator elif op == ast.NotIn: - result = result and (current_left not in comparator) + current_result = current_left not in comparator else: raise InterpreterError(f"Operator not supported: {op}") + result = result & current_result current_left = comparator - if not result: + + if isinstance(result, bool) and not result: break - return result + return result if isinstance(result, (bool, pd.Series)) else result.all() -def evaluate_if(if_statement, state, tools): +def evaluate_if(if_statement, state, static_tools, custom_tools): result = None - test_result = evaluate_ast(if_statement.test, state, tools) + test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools) if test_result: for line in if_statement.body: - line_result = evaluate_ast(line, state, tools) + line_result = evaluate_ast(line, state, static_tools, custom_tools) if line_result is not None: result = line_result else: for line in if_statement.orelse: - line_result = evaluate_ast(line, state, tools) + line_result = evaluate_ast(line, state, static_tools, custom_tools) if line_result is not None: result = line_result return result -def evaluate_for(for_loop, state, tools): +def evaluate_for(for_loop, state, static_tools, custom_tools): result = None - iterator = evaluate_ast(for_loop.iter, state, tools) + iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools) for counter in iterator: - if isinstance(for_loop.target, ast.Tuple): - for i, elem in enumerate(for_loop.target.elts): - state[elem.id] = counter[i] - else: - state[for_loop.target.id] = counter + set_value(for_loop.target, counter, state, static_tools, custom_tools) for node in for_loop.body: try: - line_result = evaluate_ast(node, state, tools) + line_result = evaluate_ast(node, state, static_tools, custom_tools) if line_result is not None: result = line_result except BreakException: @@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools): return result -def evaluate_listcomp(listcomp, state, tools): - result = [] - for generator in listcomp.generators: - iter_value = evaluate_ast(generator.iter, state, tools) +def evaluate_listcomp(listcomp, state, static_tools, custom_tools): + def inner_evaluate(generators, index, current_state): + if index >= len(generators): + return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)] + generator = generators[index] + iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools) + result = [] for value in iter_value: - new_state = state.copy() + new_state = current_state.copy() if isinstance(generator.target, ast.Tuple): for idx, elem in enumerate(generator.target.elts): new_state[elem.id] = value[idx] else: new_state[generator.target.id] = value - if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs): - result.append(evaluate_ast(listcomp.elt, new_state, tools)) - return result + if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs): + result.extend(inner_evaluate(generators, index + 1, new_state)) + return result + + return inner_evaluate(listcomp.generators, 0, state) -def evaluate_try(try_node, state, tools): +def evaluate_try(try_node, state, static_tools, custom_tools): try: for stmt in try_node.body: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, static_tools, custom_tools) except Exception as e: matched = False for handler in try_node.handlers: - if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)): + if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)): matched = True if handler.name: state[handler.name] = e for stmt in handler.body: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, static_tools, custom_tools) break if not matched: raise e else: if try_node.orelse: for stmt in try_node.orelse: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, static_tools, custom_tools) finally: if try_node.finalbody: for stmt in try_node.finalbody: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, static_tools, custom_tools) -def evaluate_raise(raise_node, state, tools): +def evaluate_raise(raise_node, state, static_tools, custom_tools): if raise_node.exc is not None: - exc = evaluate_ast(raise_node.exc, state, tools) + exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools) else: exc = None if raise_node.cause is not None: - cause = evaluate_ast(raise_node.cause, state, tools) + cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools) else: cause = None if exc is not None: @@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools): raise InterpreterError("Re-raise is not supported without an active exception") -def evaluate_assert(assert_node, state, tools): - test_result = evaluate_ast(assert_node.test, state, tools) +def evaluate_assert(assert_node, state, static_tools, custom_tools): + test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools) if not test_result: if assert_node.msg: - msg = evaluate_ast(assert_node.msg, state, tools) + msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools) raise AssertionError(msg) else: # Include the failing condition in the assertion message @@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools): raise AssertionError(f"Assertion failed: {test_code}") -def evaluate_with(with_node, state, tools): +def evaluate_with(with_node, state, static_tools, custom_tools): contexts = [] for item in with_node.items: - context_expr = evaluate_ast(item.context_expr, state, tools) + context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools) if item.optional_vars: state[item.optional_vars.id] = context_expr.__enter__() contexts.append(state[item.optional_vars.id]) @@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools): try: for stmt in with_node.body: - evaluate_ast(stmt, state, tools) + evaluate_ast(stmt, state, static_tools, custom_tools) except Exception as e: for context in reversed(contexts): context.__exit__(type(e), e, e.__traceback__) @@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools): context.__exit__(None, None, None) +def import_modules(expression, state, authorized_imports): + def check_module_authorized(module_name): + module_path = module_name.split(".") + module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] + return any(subpath in authorized_imports for subpath in module_subpaths) + + if isinstance(expression, ast.Import): + for alias in expression.names: + if check_module_authorized(alias.name): + module = import_module(alias.name) + state[alias.asname or alias.name] = module + else: + raise InterpreterError( + f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" + ) + return None + elif isinstance(expression, ast.ImportFrom): + if check_module_authorized(expression.module): + module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) + for alias in expression.names: + state[alias.asname or alias.name] = getattr(module, alias.name) + else: + raise InterpreterError(f"Import from {expression.module} is not allowed.") + return None + + +def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools): + result = {} + for gen in dictcomp.generators: + iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools) + for value in iter_value: + new_state = state.copy() + set_value(gen.target, value, new_state, static_tools, custom_tools) + if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs): + key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools) + val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools) + result[key] = val + return result + + def evaluate_ast( expression: ast.AST, state: Dict[str, Any], - tools: Dict[str, Callable], + static_tools: Dict[str, Callable], + custom_tools: Dict[str, Callable], authorized_imports: List[str] = LIST_SAFE_MODULES, ): """ @@ -632,146 +719,128 @@ def evaluate_ast( state (`Dict[str, Any]`): A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation encounters assignements. - tools (`Dict[str, Callable]`): - The functions that may be called during the evaluation. Any call to another function will fail with an - `InterpreterError`. + static_tools (`Dict[str, Callable]`): + Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error. + custom_tools (`Dict[str, Callable]`): + Functions that may be called during the evaluation. These static_tools can be overwritten. authorized_imports (`List[str]`): The list of modules that can be imported by the code. By default, only a few safe modules are allowed. Add more at your own risk! """ + global OPERATIONS_COUNT + if OPERATIONS_COUNT >= MAX_OPERATIONS: + raise InterpreterError( + f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations." + ) + OPERATIONS_COUNT += 1 if isinstance(expression, ast.Assign): # Assignement -> we evaluate the assignment which should update the state # We return the variable assigned as it may be used to determine the final result. - return evaluate_assign(expression, state, tools) + return evaluate_assign(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.AugAssign): - return evaluate_augassign(expression, state, tools) + return evaluate_augassign(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Call): # Function call -> we return the value of the function call - return evaluate_call(expression, state, tools) + return evaluate_call(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Constant): # Constant -> just return the value return expression.value elif isinstance(expression, ast.Tuple): - return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts) - elif isinstance(expression, ast.ListComp): - return evaluate_listcomp(expression, state, tools) + return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts) + elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): + return evaluate_listcomp(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.UnaryOp): - return evaluate_unaryop(expression, state, tools) + return evaluate_unaryop(expression, state, static_tools, custom_tools) + elif isinstance(expression, ast.Starred): + return evaluate_ast(expression.value, state, static_tools, custom_tools) elif isinstance(expression, ast.BoolOp): # Boolean operation -> evaluate the operation - return evaluate_boolop(expression, state, tools) + return evaluate_boolop(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Break): raise BreakException() elif isinstance(expression, ast.Continue): raise ContinueException() elif isinstance(expression, ast.BinOp): # Binary operation -> execute operation - return evaluate_binop(expression, state, tools) + return evaluate_binop(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Compare): # Comparison -> evaluate the comparison - return evaluate_condition(expression, state, tools) + return evaluate_condition(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Lambda): - return evaluate_lambda(expression, state, tools) + return evaluate_lambda(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.FunctionDef): - return evaluate_function_def(expression, state, tools) + return evaluate_function_def(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Dict): # Dict -> evaluate all keys and values - keys = [evaluate_ast(k, state, tools) for k in expression.keys] - values = [evaluate_ast(v, state, tools) for v in expression.values] + keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys] + values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values] return dict(zip(keys, values)) elif isinstance(expression, ast.Expr): # Expression -> evaluate the content - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, static_tools, custom_tools) elif isinstance(expression, ast.For): # For loop -> execute the loop - return evaluate_for(expression, state, tools) + return evaluate_for(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.FormattedValue): # Formatted value (part of f-string) -> evaluate the content and return - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, static_tools, custom_tools) elif isinstance(expression, ast.If): # If -> execute the right branch - return evaluate_if(expression, state, tools) + return evaluate_if(expression, state, static_tools, custom_tools) elif hasattr(ast, "Index") and isinstance(expression, ast.Index): - return evaluate_ast(expression.value, state, tools) + return evaluate_ast(expression.value, state, static_tools, custom_tools) elif isinstance(expression, ast.JoinedStr): - return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values]) + return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values]) elif isinstance(expression, ast.List): # List -> evaluate all elements - return [evaluate_ast(elt, state, tools) for elt in expression.elts] + return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts] elif isinstance(expression, ast.Name): # Name -> pick up the value in the state - return evaluate_name(expression, state, tools) + return evaluate_name(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Subscript): # Subscript -> return the value of the indexing - return evaluate_subscript(expression, state, tools) + return evaluate_subscript(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.IfExp): - test_val = evaluate_ast(expression.test, state, tools) + test_val = evaluate_ast(expression.test, state, static_tools, custom_tools) if test_val: - return evaluate_ast(expression.body, state, tools) + return evaluate_ast(expression.body, state, static_tools, custom_tools) else: - return evaluate_ast(expression.orelse, state, tools) + return evaluate_ast(expression.orelse, state, static_tools, custom_tools) elif isinstance(expression, ast.Attribute): - obj = evaluate_ast(expression.value, state, tools) - return getattr(obj, expression.attr) + value = evaluate_ast(expression.value, state, static_tools, custom_tools) + return getattr(value, expression.attr) elif isinstance(expression, ast.Slice): return slice( - evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None, - evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None, - evaluate_ast(expression.step, state, tools) if expression.step is not None else None, + evaluate_ast(expression.lower, state, static_tools, custom_tools) + if expression.lower is not None + else None, + evaluate_ast(expression.upper, state, static_tools, custom_tools) + if expression.upper is not None + else None, + evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None, ) - elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp): - result = [] - vars = {} - for generator in expression.generators: - var_name = generator.target.id - iter_value = evaluate_ast(generator.iter, state, tools) - for value in iter_value: - vars[var_name] = value - if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs): - elem = evaluate_ast(expression.elt, {**state, **vars}, tools) - result.append(elem) - return result elif isinstance(expression, ast.DictComp): - result = {} - for gen in expression.generators: - for container in get_iterable(evaluate_ast(gen.iter, state, tools)): - state[gen.target.id] = container - key = evaluate_ast(expression.key, state, tools) - value = evaluate_ast(expression.value, state, tools) - result[key] = value - return result - elif isinstance(expression, ast.Import): - for alias in expression.names: - if alias.name in authorized_imports: - module = __import__(alias.name) - state[alias.asname or alias.name] = module - else: - raise InterpreterError(f"Import of {alias.name} is not allowed.") - return None + return evaluate_dictcomp(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.While): - return evaluate_while(expression, state, tools) - elif isinstance(expression, ast.ImportFrom): - if expression.module in authorized_imports: - module = __import__(expression.module) - for alias in expression.names: - state[alias.asname or alias.name] = getattr(module, alias.name) - else: - raise InterpreterError(f"Import from {expression.module} is not allowed.") - return None + return evaluate_while(expression, state, static_tools, custom_tools) + elif isinstance(expression, (ast.Import, ast.ImportFrom)): + return import_modules(expression, state, authorized_imports) elif isinstance(expression, ast.ClassDef): - return evaluate_class_def(expression, state, tools) + return evaluate_class_def(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Try): - return evaluate_try(expression, state, tools) + return evaluate_try(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Raise): - return evaluate_raise(expression, state, tools) + return evaluate_raise(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Assert): - return evaluate_assert(expression, state, tools) + return evaluate_assert(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.With): - return evaluate_with(expression, state, tools) + return evaluate_with(expression, state, static_tools, custom_tools) elif isinstance(expression, ast.Set): - return {evaluate_ast(elt, state, tools) for elt in expression.elts} + return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts} elif isinstance(expression, ast.Return): - raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None) + raise ReturnException( + evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None + ) else: # For now we refuse anything else. Let's add things as we need them. raise InterpreterError(f"{expression.__class__.__name__} is not supported.") @@ -779,7 +848,8 @@ def evaluate_ast( def evaluate_python_code( code: str, - tools: Optional[Dict[str, Callable]] = None, + static_tools: Optional[Dict[str, Callable]] = None, + custom_tools: Optional[Dict[str, Callable]] = None, state: Optional[Dict[str, Any]] = None, authorized_imports: List[str] = LIST_SAFE_MODULES, ): @@ -792,9 +862,12 @@ def evaluate_python_code( Args: code (`str`): The code to evaluate. - tools (`Dict[str, Callable]`): - The functions that may be called during the evaluation. Any call to another function will fail with an - `InterpreterError`. + static_tools (`Dict[str, Callable]`): + The functions that may be called during the evaluation. + These tools cannot be overwritten in the code: any assignment to their name will raise an error. + custom_tools (`Dict[str, Callable]`): + The functions that may be called during the evaluation. + These tools can be overwritten in the code: any assignment to their name will overwrite them. state (`Dict[str, Any]`): A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be updated by this function to contain all variables as they are evaluated. @@ -806,20 +879,34 @@ def evaluate_python_code( raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") if state is None: state = {} - if tools is None: - tools = {} + if static_tools is None: + static_tools = {} + if custom_tools is None: + custom_tools = {} result = None global PRINT_OUTPUTS PRINT_OUTPUTS = "" + global OPERATIONS_COUNT + OPERATIONS_COUNT = 0 for node in expression.body: try: - result = evaluate_ast(node, state, tools, authorized_imports) + result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) except InterpreterError as e: - msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" + msg = "" if len(PRINT_OUTPUTS) > 0: - msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n" + if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT: + msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n" + else: + msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n" + msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" raise InterpreterError(msg) finally: - state["print_outputs"] = PRINT_OUTPUTS + if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT: + state["print_outputs"] = PRINT_OUTPUTS + else: + state["print_outputs"] = ( + PRINT_OUTPUTS[:MAX_LEN_OUTPUT] + + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._" + ) return result diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 5bdaea1651b741..f47d0b0c35c3e0 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -223,7 +223,7 @@ def test_init_agent_with_different_toolsets(self): # check that add_base_tools will not interfere with existing tools with pytest.raises(KeyError) as e: agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True) - assert "python_interpreter already exists in the toolbox" in str(e) + assert "already exists in the toolbox" in str(e) # check that python_interpreter base tool does not get added to code agents agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 8843a394b35313..8614302baae764 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -15,6 +15,7 @@ import unittest +import numpy as np import pytest from transformers import load_tool @@ -241,8 +242,41 @@ def test_tuples(self): code = """ digits, i = [1, 2, 3], 1 digits[i], digits[i + 1] = digits[i + 1], digits[i]""" + evaluate_python_code(code, {"range": range, "print": print, "int": int}, {}) + + code = """ +def calculate_isbn_10_check_digit(number): + total = sum((10 - i) * int(digit) for i, digit in enumerate(number)) + remainder = total % 11 + check_digit = 11 - remainder + if check_digit == 10: + return 'X' + elif check_digit == 11: + return '0' + else: + return str(check_digit) + +# Given 9-digit numbers +numbers = [ + "478225952", + "643485613", + "739394228", + "291726859", + "875262394", + "542617795", + "031810713", + "957007669", + "871467426" +] + +# Calculate check digits for each number +check_digits = [calculate_isbn_10_check_digit(number) for number in numbers] +print(check_digits) +""" state = {} - evaluate_python_code(code, {"range": range, "print": print, "int": int}, state) + evaluate_python_code( + code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state + ) def test_listcomp(self): code = "x = [i for i in range(3)]" @@ -273,6 +307,17 @@ def test_dictcomp(self): result = evaluate_python_code(code, {"range": range}, state={}) assert result == {0: 0, 1: 1, 2: 4} + code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" + result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) + assert result == {102: "b"} + + code = """ +shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')} +shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()} +""" + result = evaluate_python_code(code, {}, state={}) + assert result == {"A": ("a", "b"), "B": ("a", "b")} + def test_tuple_assignment(self): code = "a, b = 0, 1\nb" result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) @@ -341,7 +386,7 @@ def test_imports(self): result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result == "lose" - code = "import time\ntime.sleep(0.1)" + code = "import time, re\ntime.sleep(0.1)" result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result is None @@ -369,6 +414,23 @@ def test_imports(self): result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result == "LATIN CAPITAL LETTER A" + # Test submodules are handled properly, thus not raising error + code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) + + code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) + + def test_additional_imports(self): + code = "import numpy as np" + evaluate_python_code(code, authorized_imports=["numpy"], state={}) + + code = "import numpy.random as rd" + evaluate_python_code(code, authorized_imports=["numpy.random"], state={}) + evaluate_python_code(code, authorized_imports=["numpy"], state={}) + with pytest.raises(InterpreterError): + evaluate_python_code(code, authorized_imports=["random"], state={}) + def test_multiple_comparators(self): code = "0 <= -1 < 4 and 0 <= -5 < 4" result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) @@ -400,7 +462,7 @@ def function(): print("2") function()""" state = {} - evaluate_python_code(code, {"print": print}, state) + evaluate_python_code(code, {"print": print}, state=state) assert state["print_outputs"] == "1\n2\n" def test_tuple_target_in_iterator(self): @@ -612,7 +674,7 @@ def __exit__(self, exc_type, exc_value, traceback): """ state = {} tools = {} - evaluate_python_code(code, tools, state) + evaluate_python_code(code, tools, state=state) def test_default_arg_in_function(self): code = """ @@ -672,3 +734,94 @@ def returns_none(a): state = {} result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) assert result is None + + def test_nested_for_loop(self): + code = """ +all_res = [] +for i in range(10): + subres = [] + for j in range(i): + subres.append(j) + all_res.append(subres) + +out = [i for sublist in all_res for i in sublist] +out[:10] +""" + state = {} + result = evaluate_python_code(code, {"print": print, "range": range}, state=state) + assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] + + def test_pandas(self): + code = """ +import pandas as pd + +df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]}) + +df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce') + +parts_with_5_set_count = df[df['SetCount'] == 5.0] +parts_with_5_set_count[['Quantity', 'SetCount']].values[1] +""" + state = {} + result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"]) + assert np.array_equal(result, [-1, 5]) + + code = """ +import pandas as pd + +df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]}) +print("HH0") + +# Filter the DataFrame to get only the rows with outdated atomic numbers +filtered_df = df.loc[df['AtomicNumber'].isin([104])] +""" + result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) + assert np.array_equal(result.values[0], [104, 1]) + + code = """import pandas as pd +data = pd.DataFrame.from_dict([ + {"Pclass": 1, "Survived": 1}, + {"Pclass": 2, "Survived": 0}, + {"Pclass": 2, "Survived": 1} +]) +survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() +""" + result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"]) + assert result.values[1] == 0.5 + + def test_starred(self): + code = """ +from math import radians, sin, cos, sqrt, atan2 + +def haversine(lat1, lon1, lat2, lon2): + R = 6371000 # Radius of the Earth in meters + lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2]) + dlat = lat2 - lat1 + dlon = lon2 - lon1 + a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2 + c = 2 * atan2(sqrt(a), sqrt(1 - a)) + distance = R * c + return distance + +coords_geneva = (46.1978, 6.1342) +coords_barcelona = (41.3869, 2.1660) + +distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) +""" + result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"]) + assert round(result, 1) == 622395.4 + + def test_for(self): + code = """ +shifts = { + "Worker A": ("6:45 pm", "8:00 pm"), + "Worker B": ("10:00 am", "11:45 am") +} + +shift_intervals = {} +for worker, (start, end) in shifts.items(): + shift_intervals[worker] = end +shift_intervals +""" + result = evaluate_python_code(code, {"print": print, "map": map}, state={}) + assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}