diff --git a/src/transformers/tools/agents.py b/src/transformers/tools/agents.py index 3e423ebb30556d..86adde93ef3001 100644 --- a/src/transformers/tools/agents.py +++ b/src/transformers/tools/agents.py @@ -315,6 +315,13 @@ def prepare_for_new_chat(self): self.chat_state = {} self.cached_tools = None + def clean_code_for_run(self, result): + """ + Override this method if you want to change the way the code is + cleaned for the `run` method. + """ + return clean_code_for_run(result) + def run(self, task, *, return_code=False, remote=False, **kwargs): """ Sends a request to the agent. @@ -339,7 +346,7 @@ def run(self, task, *, return_code=False, remote=False, **kwargs): """ prompt = self.format_prompt(task) result = self.generate_one(prompt, stop=["Task:"]) - explanation, code = clean_code_for_run(result) + explanation, code = self.clean_code_for_run(result) self.log(f"==Explanation from the agent==\n{explanation}") diff --git a/src/transformers/tools/evaluate_agent.py b/src/transformers/tools/evaluate_agent.py index 7d5cddf1c9d01f..e9d14fab56cd6b 100644 --- a/src/transformers/tools/evaluate_agent.py +++ b/src/transformers/tools/evaluate_agent.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat, clean_code_for_run +from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat from .python_interpreter import InterpretorError, evaluate @@ -554,7 +554,7 @@ def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False): problem = EVALUATION_TASKS[eval_idx[start_idx + idx]] if verbose: print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n") - explanation, code = clean_code_for_run(result) + explanation, code = agent.clean_code_for_run(result) # Evaluate agent answer and code answer agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)